MLIR  21.0.0git
Shape.cpp
Go to the documentation of this file.
1 //===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include <utility>

#include "mlir/Dialect/Shape/IR/Shape.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
31 
32 using namespace mlir;
33 using namespace mlir::shape;
34 
35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
36 
37 namespace {
38 #include "ShapeCanonicalization.inc"
39 } // namespace
40 
41 RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
42  return RankedTensorType::get({rank}, IndexType::get(ctx));
43 }
44 
46  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
48 }
49 
50 LogicalResult shape::getShapeVec(Value input,
51  SmallVectorImpl<int64_t> &shapeValues) {
52  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
53  auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
54  if (!type.hasRank())
55  return failure();
56  llvm::append_range(shapeValues, type.getShape());
57  return success();
58  }
60  if (matchPattern(input, m_Constant(&attr))) {
61  llvm::append_range(shapeValues, attr.getValues<int64_t>());
62  return success();
63  }
64  return failure();
65 }
66 
67 static bool isErrorPropagationPossible(TypeRange operandTypes) {
68  return llvm::any_of(operandTypes,
69  llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
70 }
71 
72 static LogicalResult verifySizeOrIndexOp(Operation *op) {
73  assert(op != nullptr && op->getNumResults() == 1);
74  Type resultTy = op->getResultTypes().front();
76  if (!llvm::isa<SizeType>(resultTy))
77  return op->emitOpError()
78  << "if at least one of the operands can hold error values then "
79  "the result must be of type `size` to propagate them";
80  }
81  return success();
82 }
83 
84 static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
85  assert(op != nullptr && op->getNumResults() == 1);
86  Type resultTy = op->getResultTypes().front();
88  if (!llvm::isa<ShapeType>(resultTy))
89  return op->emitOpError()
90  << "if at least one of the operands can hold error values then "
91  "the result must be of type `shape` to propagate them";
92  }
93  return success();
94 }
95 
96 template <typename... Ty>
97 static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
98  return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
99 }
100 
101 template <typename... Ty, typename... ranges>
102 static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
103  return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // InlinerInterface
108 //===----------------------------------------------------------------------===//
109 
110 namespace {
111 /// This class defines the interface for inlining shape dialect ops.
112 struct ShapeInlinerInterface : public DialectInlinerInterface {
114 
115  // Returns true if the given region 'src' can be inlined into the region
116  // 'dest' that is attached to an operation registered to the current dialect.
117  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
118  IRMapping &) const final {
119  return true;
120  }
121 
122  // Returns true if the given operation 'op', that is registered to this
123  // dialect, can be inlined into the region 'dest' that is attached to an
124  // operation registered to the current dialect.
125  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
126  IRMapping &) const final {
127  return true;
128  }
129 };
130 } // namespace
131 
132 void ShapeDialect::initialize() {
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
136  >();
137  addTypes<
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140  >();
141  addInterfaces<ShapeInlinerInterface>();
142  // Allow unknown operations during prototyping and testing. As the dialect is
143  // still evolving it makes it simple to start with an unregistered ops and
144  // try different variants before actually defining the op.
145  allowUnknownOperations();
146  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
147  AssumingYieldOp>();
148 }
149 
151  Attribute value, Type type,
152  Location loc) {
153  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
154  return builder.create<ub::PoisonOp>(loc, type, poison);
155 
156  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
157  return builder.create<ConstShapeOp>(
158  loc, type, llvm::cast<DenseIntElementsAttr>(value));
159  if (llvm::isa<SizeType>(type))
160  return builder.create<ConstSizeOp>(loc, type,
161  llvm::cast<IntegerAttr>(value));
162  if (llvm::isa<WitnessType>(type))
163  return builder.create<ConstWitnessOp>(loc, type,
164  llvm::cast<BoolAttr>(value));
165 
166  return arith::ConstantOp::materialize(builder, value, type, loc);
167 }
168 
169 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
170  NamedAttribute attribute) {
171  // Verify shape.lib attribute.
172  if (attribute.getName() == "shape.lib") {
173  if (!op->hasTrait<OpTrait::SymbolTable>())
174  return op->emitError(
175  "shape.lib attribute may only be on op implementing SymbolTable");
176 
177  if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
178  auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
179  if (!symbol)
180  return op->emitError("shape function library ")
181  << symbolRef << " not found";
182  return isa<shape::FunctionLibraryOp>(symbol)
183  ? success()
184  : op->emitError()
185  << symbolRef << " required to be shape function library";
186  }
187 
188  if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
189  // Verify all entries are function libraries and mappings in libraries
190  // refer to unique ops.
192  for (auto it : arr) {
193  if (!llvm::isa<SymbolRefAttr>(it))
194  return op->emitError(
195  "only SymbolRefAttr allowed in shape.lib attribute array");
196 
197  auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
198  SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
199  if (!shapeFnLib)
200  return op->emitError()
201  << it << " does not refer to FunctionLibraryOp";
202  for (auto mapping : shapeFnLib.getMapping()) {
203  if (!key.insert(mapping.getName()).second) {
204  return op->emitError("only one op to shape mapping allowed, found "
205  "multiple for `")
206  << mapping.getName() << "`";
207  }
208  }
209  }
210  return success();
211  }
212 
213  return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
214  "allowed as shape.lib attribute");
215  }
216  return success();
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // AnyOp
221 //===----------------------------------------------------------------------===//
222 
223 // TODO: Canonicalization should be implemented for shapes that can be
224 // determined through mixtures of the known dimensions of the inputs.
225 OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
226  // Only the last operand is checked because AnyOp is commutative.
227  if (adaptor.getInputs().back())
228  return adaptor.getInputs().back();
229 
230  return nullptr;
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // AssumingOp
235 //===----------------------------------------------------------------------===//
236 
237 ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
238  result.regions.reserve(1);
239  Region *doRegion = result.addRegion();
240 
241  auto &builder = parser.getBuilder();
243  if (parser.parseOperand(cond) ||
244  parser.resolveOperand(cond, builder.getType<WitnessType>(),
245  result.operands))
246  return failure();
247 
248  // Parse optional results type list.
249  if (parser.parseOptionalArrowTypeList(result.types))
250  return failure();
251 
252  // Parse the region and add a terminator if elided.
253  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
254  return failure();
255  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
256 
257  // Parse the optional attribute list.
258  if (parser.parseOptionalAttrDict(result.attributes))
259  return failure();
260  return success();
261 }
262 
264  bool yieldsResults = !getResults().empty();
265 
266  p << " " << getWitness();
267  if (yieldsResults)
268  p << " -> (" << getResultTypes() << ")";
269  p << ' ';
270  p.printRegion(getDoRegion(),
271  /*printEntryBlockArgs=*/false,
272  /*printBlockTerminators=*/yieldsResults);
273  p.printOptionalAttrDict((*this)->getAttrs());
274 }
275 
276 namespace {
277 // Removes AssumingOp with a passing witness and inlines the region.
278 struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
280 
281  LogicalResult matchAndRewrite(AssumingOp op,
282  PatternRewriter &rewriter) const override {
283  auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284  if (!witness || !witness.getPassingAttr())
285  return failure();
286 
287  AssumingOp::inlineRegionIntoParent(op, rewriter);
288  return success();
289  }
290 };
291 
292 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
294 
295  LogicalResult matchAndRewrite(AssumingOp op,
296  PatternRewriter &rewriter) const override {
297  Block *body = op.getBody();
298  auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
299 
300  // Find used values.
301  SmallVector<Value, 4> newYieldOperands;
302  for (auto [opResult, yieldOperand] :
303  llvm::zip(op.getResults(), yieldOp.getOperands())) {
304  if (!opResult.getUses().empty()) {
305  newYieldOperands.push_back(yieldOperand);
306  }
307  }
308 
309  // Rewrite only if redundant results exist.
310  if (newYieldOperands.size() == yieldOp->getNumOperands())
311  return failure();
312 
313  // Replace yield op in the old assuming op's body and move the entire region
314  // to the new assuming op.
315  rewriter.setInsertionPointToEnd(body);
316  auto newYieldOp =
317  rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
318  rewriter.setInsertionPoint(op);
319  auto newOp = rewriter.create<AssumingOp>(
320  op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321  newOp.getDoRegion().takeBody(op.getDoRegion());
322 
323  // Use the new results to replace the previously used ones.
324  SmallVector<Value, 4> replacementValues;
325  auto src = newOp.getResults().begin();
326  for (auto it : op.getResults()) {
327  if (it.getUses().empty())
328  replacementValues.push_back(nullptr);
329  else
330  replacementValues.push_back(*src++);
331  }
332  rewriter.replaceOp(op, replacementValues);
333  return success();
334  }
335 };
336 } // namespace
337 
338 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
339  MLIRContext *context) {
340  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
341 }
342 
343 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
344 void AssumingOp::getSuccessorRegions(
346  // AssumingOp has unconditional control flow into the region and back to the
347  // parent, so return the correct RegionSuccessor purely based on the index
348  // being None or 0.
349  if (!point.isParent()) {
350  regions.push_back(RegionSuccessor(getResults()));
351  return;
352  }
353 
354  regions.push_back(RegionSuccessor(&getDoRegion()));
355 }
356 
357 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
358  PatternRewriter &rewriter) {
359  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
360  auto *assumingBlock = op.getBody();
361  auto initPosition = rewriter.getInsertionPoint();
362  auto *blockAfterAssuming =
363  rewriter.splitBlock(blockBeforeAssuming, initPosition);
364 
365  // Remove the AssumingOp and AssumingYieldOp.
366  auto &yieldOp = assumingBlock->back();
367  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
368  rewriter.replaceOp(op, yieldOp.getOperands());
369  rewriter.eraseOp(&yieldOp);
370 
371  // Merge blocks together as there was no branching behavior from the
372  // AssumingOp.
373  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
374  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
375 }
376 
377 void AssumingOp::build(
378  OpBuilder &builder, OperationState &result, Value witness,
380  OpBuilder::InsertionGuard g(builder);
381 
382  result.addOperands(witness);
383  Region *bodyRegion = result.addRegion();
384  builder.createBlock(bodyRegion);
385 
386  // Build body.
387  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
388  builder.create<AssumingYieldOp>(result.location, yieldValues);
389 
390  SmallVector<Type, 2> assumingTypes;
391  for (Value v : yieldValues)
392  assumingTypes.push_back(v.getType());
393  result.addTypes(assumingTypes);
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // AddOp
398 //===----------------------------------------------------------------------===//
399 
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401  MLIRContext *context, std::optional<Location> location,
402  AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
403  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404  llvm::isa<SizeType>(adaptor.getRhs().getType()))
405  inferredReturnTypes.assign({SizeType::get(context)});
406  else
407  inferredReturnTypes.assign({IndexType::get(context)});
408  return success();
409 }
410 
411 bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
412  // SizeType is compatible with IndexType.
413  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
414 }
415 
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
417  // add(x, 0) -> x
418  if (matchPattern(getRhs(), m_Zero()))
419  return getLhs();
420 
421  return constFoldBinaryOp<IntegerAttr>(
422  adaptor.getOperands(),
423  [](APInt a, const APInt &b) { return std::move(a) + b; });
424 }
425 
426 LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
427 
428 //===----------------------------------------------------------------------===//
429 // AssumingAllOp
430 //===----------------------------------------------------------------------===//
431 
432 namespace {
433 
434 // Merge multiple `shape.assuming_all` operations together.
435 //
436 // %0 = shape.assuming_all %w0, %w1
437 // %1 = shape.assuming_all %w2, %0
438 //
439 // to:
440 //
441 // %0 = shape.assuming_all %w0, %w2, %w2
442 struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
444 
445  LogicalResult matchAndRewrite(AssumingAllOp op,
446  PatternRewriter &rewriter) const override {
447  SmallVector<Value> operands;
448 
449  for (Value operand : op.getInputs()) {
450  if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451  operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452  else
453  operands.push_back(operand);
454  }
455 
456  // We didn't find any other `assuming_all` ops to merge with.
457  if (operands.size() == op.getNumOperands())
458  return failure();
459 
460  // Replace with a new `assuming_all` operation with merged constraints.
461  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
462  return success();
463  }
464 };
465 
466 // Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
467 // are subsumed by others.
468 //
469 // %0 = shape.cstr_broadcastable %shape0, %shape1
470 // %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
471 //
472 // %2 = shape.cstr_broadcastable %shape3, %shape4
473 // %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
474 //
475 // %4 = shape.assuming_all %0, %1, %2, %3
476 //
477 // to:
478 //
479 // %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
480 // %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
481 // %2 = shape.assuming_all %0, %1
482 //
483 // In this example if shapes [0, 1, 2] are broadcastable, then it means that
484 // shapes [0, 1] are broadcastable too, and can be removed from the list of
485 // constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
486 // matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
487 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
489 
490  LogicalResult matchAndRewrite(AssumingAllOp op,
491  PatternRewriter &rewriter) const override {
492  // Collect all `CstrBroadcastableOp` operands first.
494  for (Value operand : op.getInputs()) {
495  // TODO: Apply this optimization if some of the witnesses are not
496  // produced by the `cstr_broadcastable`.
497  auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
498  if (!broadcastable)
499  return failure();
500 
501  operands.insert(broadcastable);
502  }
503 
504  // Skip trivial `assuming_all` operations.
505  if (operands.size() <= 1)
506  return failure();
507 
508  // Collect shapes checked by `cstr_broadcastable` operands.
510  for (auto cstr : operands) {
511  DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
512  shapes.emplace_back(cstr, std::move(shapesSet));
513  }
514 
515  // Sort by the number of shape operands (larger to smaller).
516  llvm::sort(shapes, [](auto a, auto b) {
517  return a.first.getNumOperands() > b.first.getNumOperands();
518  });
519 
520  // We start from the `cst_broadcastable` operations with largest number of
521  // shape operands, and remove redundant `cst_broadcastable` operations. We
522  // do this until we find a set of `cst_broadcastable` operations with
523  // non-overlapping constraints.
524  SmallVector<CstrBroadcastableOp> markedForErase;
525 
526  for (unsigned i = 0; i < shapes.size(); ++i) {
527  auto isSubset = [&](auto pair) {
528  return llvm::set_is_subset(pair.second, shapes[i].second);
529  };
530 
531  // Keep redundant `cstr_broadcastable` operations to be erased.
532  auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533  for (auto *it0 = it; it0 < shapes.end(); ++it0)
534  markedForErase.push_back(it0->first);
535  shapes.erase(it, shapes.end());
536  }
537 
538  // We didn't find any operands that could be removed.
539  if (markedForErase.empty())
540  return failure();
541 
542  // Collect non-overlapping `cst_broadcastable` constraints.
543  SmallVector<Value> uniqueConstraints;
544  for (auto &shape : shapes)
545  uniqueConstraints.push_back(shape.first.getResult());
546 
547  // Replace with a new `assuming_all` operation ...
548  rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
549 
550  // ... and maybe erase `cstr_broadcastable` ops without uses.
551  for (auto &op : markedForErase)
552  if (op->use_empty())
553  rewriter.eraseOp(op);
554 
555  return success();
556  }
557 };
558 
559 struct AssumingAllToCstrEqCanonicalization
560  : public OpRewritePattern<AssumingAllOp> {
562 
563  LogicalResult matchAndRewrite(AssumingAllOp op,
564  PatternRewriter &rewriter) const override {
565  SmallVector<Value, 8> shapes;
566  for (Value w : op.getInputs()) {
567  auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
568  if (!cstrEqOp)
569  return failure();
570  bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
571  return llvm::is_contained(shapes, s);
572  });
573  if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574  return failure();
575  shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
576  }
577  rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
578  return success();
579  }
580 };
581 
582 template <typename OpTy>
583 struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
585 
586  LogicalResult matchAndRewrite(OpTy op,
587  PatternRewriter &rewriter) const override {
588  // Find unique operands.
589  SetVector<Value> unique(op.operand_begin(), op.operand_end());
590 
591  // Reduce op to equivalent with unique operands.
592  if (unique.size() < op.getNumOperands()) {
593  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
594  unique.takeVector(), op->getAttrs());
595  return success();
596  }
597 
598  return failure();
599  }
600 };
601 } // namespace
602 
603 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
604  MLIRContext *context) {
605  patterns
606  .add<MergeAssumingAllOps, AssumingAllOneOp,
607  AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608  RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
609 }
610 
611 OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
612  // Iterate in reverse to first handle all constant operands. They are
613  // guaranteed to be the tail of the inputs because this is commutative.
614  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
615  Attribute a = adaptor.getInputs()[idx];
616  // Cannot fold if any inputs are not constant;
617  if (!a)
618  return nullptr;
619 
620  // We do not need to keep statically known values after handling them in
621  // this method.
622  getOperation()->eraseOperand(idx);
623 
624  // Always false if any input is statically known false
625  if (!llvm::cast<BoolAttr>(a).getValue())
626  return a;
627  }
628  // If this is reached, all inputs were statically known passing.
629  return BoolAttr::get(getContext(), true);
630 }
631 
632 LogicalResult AssumingAllOp::verify() {
633  // Ensure that AssumingAllOp contains at least one operand
634  if (getNumOperands() == 0)
635  return emitOpError("no operands specified");
636 
637  return success();
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // BroadcastOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
645  if (getShapes().size() == 1) {
646  // Otherwise, we need a cast which would be a canonicalization, not folding.
647  if (getShapes().front().getType() != getType())
648  return nullptr;
649  return getShapes().front();
650  }
651 
652  if (!adaptor.getShapes().front())
653  return nullptr;
654 
655  SmallVector<int64_t, 6> resultShape(
656  llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
657  .getValues<int64_t>());
658 
659  for (auto next : adaptor.getShapes().drop_front()) {
660  if (!next)
661  return nullptr;
662  auto nextShape = llvm::to_vector<6>(
663  llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
664 
665  SmallVector<int64_t, 6> tmpShape;
666  // If the shapes are not compatible, we can't fold it.
667  // TODO: Fold to an "error".
668  if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
669  return nullptr;
670 
671  resultShape.clear();
672  std::copy(tmpShape.begin(), tmpShape.end(),
673  std::back_inserter(resultShape));
674  }
675 
676  Builder builder(getContext());
677  return builder.getIndexTensorAttr(resultShape);
678 }
679 
680 LogicalResult BroadcastOp::verify() {
681  return verifyShapeOrExtentTensorOp(*this);
682 }
683 
684 namespace {
685 template <typename OpTy>
686 struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
688 
689  LogicalResult matchAndRewrite(OpTy op,
690  PatternRewriter &rewriter) const override {
691  auto isPotentiallyNonEmptyShape = [](Value shape) {
692  if (auto extentTensorTy =
693  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
694  if (extentTensorTy.getDimSize(0) == 0)
695  return false;
696  }
697  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
698  if (constShape.getShape().empty())
699  return false;
700  }
701  return true;
702  };
703  auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
704  isPotentiallyNonEmptyShape);
705 
706  // Replace the op with empty shape constant if all operants are reduced to
707  // be empty.
708  if (newOperands.empty()) {
709  rewriter.replaceOpWithNewOp<ConstShapeOp>(
710  op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
711  return success();
712  }
713 
714  // Reduce op to equivalent without empty shape operands.
715  if (newOperands.size() < op.getNumOperands()) {
716  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
717  op->getAttrs());
718  return success();
719  }
720 
721  return failure();
722  }
723 };
724 
725 struct BroadcastForwardSingleOperandPattern
726  : public OpRewritePattern<BroadcastOp> {
728 
729  LogicalResult matchAndRewrite(BroadcastOp op,
730  PatternRewriter &rewriter) const override {
731  if (op.getNumOperands() != 1)
732  return failure();
733  Value replacement = op.getShapes().front();
734 
735  // Insert cast if needed.
736  if (replacement.getType() != op.getType()) {
737  auto loc = op.getLoc();
738  if (llvm::isa<ShapeType>(op.getType())) {
739  replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
740  } else {
741  assert(!llvm::isa<ShapeType>(op.getType()) &&
742  !llvm::isa<ShapeType>(replacement.getType()) &&
743  "expect extent tensor cast");
744  replacement =
745  rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
746  }
747  }
748 
749  rewriter.replaceOp(op, replacement);
750  return success();
751  }
752 };
753 
754 struct BroadcastFoldConstantOperandsPattern
755  : public OpRewritePattern<BroadcastOp> {
757 
758  LogicalResult matchAndRewrite(BroadcastOp op,
759  PatternRewriter &rewriter) const override {
760  SmallVector<int64_t, 8> foldedConstantShape;
761  SmallVector<Value, 8> newShapeOperands;
762  for (Value shape : op.getShapes()) {
763  if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
764  SmallVector<int64_t, 8> newFoldedConstantShape;
766  foldedConstantShape,
767  llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
768  newFoldedConstantShape)) {
769  foldedConstantShape = newFoldedConstantShape;
770  continue;
771  }
772  }
773  newShapeOperands.push_back(shape);
774  }
775 
776  // Need at least two constant operands to fold anything.
777  if (op.getNumOperands() - newShapeOperands.size() < 2)
778  return failure();
779 
780  auto foldedConstantOperandsTy = RankedTensorType::get(
781  {static_cast<int64_t>(foldedConstantShape.size())},
782  rewriter.getIndexType());
783  newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
784  op.getLoc(), foldedConstantOperandsTy,
785  rewriter.getIndexTensorAttr(foldedConstantShape)));
786  rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
787  newShapeOperands);
788  return success();
789  }
790 };
791 
792 template <typename OpTy>
793 struct CanonicalizeCastExtentTensorOperandsPattern
794  : public OpRewritePattern<OpTy> {
796 
797  LogicalResult matchAndRewrite(OpTy op,
798  PatternRewriter &rewriter) const override {
799  // Canonicalize operands.
800  bool anyChange = false;
801  auto canonicalizeOperand = [&](Value operand) -> Value {
802  if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
803  // Only eliminate the cast if it holds no shape information.
804  bool isInformationLoosingCast =
805  llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
806  if (isInformationLoosingCast) {
807  anyChange = true;
808  return castOp.getSource();
809  }
810  }
811  return operand;
812  };
813  auto newOperands = llvm::to_vector<8>(
814  llvm::map_range(op.getOperands(), canonicalizeOperand));
815 
816  // Rewrite op if any change required.
817  if (!anyChange)
818  return failure();
819  rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
820  return success();
821  }
822 };
823 
824 struct BroadcastConcretizeResultTypePattern
825  : public OpRewritePattern<BroadcastOp> {
827 
828  LogicalResult matchAndRewrite(BroadcastOp op,
829  PatternRewriter &rewriter) const override {
830  // Only concretize dynamic extent tensor result types.
831  auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
832  if (!resultTy || !resultTy.isDynamicDim(0))
833  return failure();
834 
835  // Infer resulting shape rank if possible.
836  int64_t maxRank = 0;
837  for (Value shape : op.getShapes()) {
838  if (auto extentTensorTy =
839  llvm::dyn_cast<RankedTensorType>(shape.getType())) {
840  // Cannot infer resulting shape rank if any operand is dynamically
841  // ranked.
842  if (extentTensorTy.isDynamicDim(0))
843  return failure();
844  maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
845  }
846  }
847 
848  auto newOp = rewriter.create<BroadcastOp>(
849  op.getLoc(), getExtentTensorType(getContext(), maxRank),
850  op.getShapes());
851  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
852  return success();
853  }
854 };
855 } // namespace
856 
857 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
858  MLIRContext *context) {
859  patterns.add<BroadcastConcretizeResultTypePattern,
860  BroadcastFoldConstantOperandsPattern,
861  BroadcastForwardSingleOperandPattern,
862  CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
863  RemoveDuplicateOperandsPattern<BroadcastOp>,
864  RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
865 }
866 
867 //===----------------------------------------------------------------------===//
868 // ConcatOp
869 //===----------------------------------------------------------------------===//
870 
871 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
872  if (!adaptor.getLhs() || !adaptor.getRhs())
873  return nullptr;
874  auto lhsShape = llvm::to_vector<6>(
875  llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
876  auto rhsShape = llvm::to_vector<6>(
877  llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
878  SmallVector<int64_t, 6> resultShape;
879  resultShape.append(lhsShape.begin(), lhsShape.end());
880  resultShape.append(rhsShape.begin(), rhsShape.end());
881  Builder builder(getContext());
882  return builder.getIndexTensorAttr(resultShape);
883 }
884 
885 //===----------------------------------------------------------------------===//
886 // ConstShapeOp
887 //===----------------------------------------------------------------------===//
888 
890  p << " ";
891  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
892  p << "[";
893  interleaveComma(getShape().getValues<int64_t>(), p);
894  p << "] : ";
895  p.printType(getType());
896 }
897 
898 ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
899  if (parser.parseOptionalAttrDict(result.attributes))
900  return failure();
901  // We piggy-back on ArrayAttr parsing, though we don't internally store the
902  // shape as an ArrayAttr.
903  // TODO: Implement custom parser and maybe make syntax a bit more concise.
904  Attribute extentsRaw;
905  NamedAttrList dummy;
906  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
907  return failure();
908  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
909  if (!extentsArray)
910  return failure();
912  for (Attribute extent : extentsArray) {
913  IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
914  if (!attr)
915  return failure();
916  ints.push_back(attr.getInt());
917  }
918  Builder &builder = parser.getBuilder();
919  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
920  Type resultTy;
921  if (parser.parseColonType(resultTy))
922  return failure();
923  result.types.push_back(resultTy);
924  return success();
925 }
926 
927 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
928 
929 void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
930  MLIRContext *context) {
931  patterns.add<TensorCastConstShape>(context);
932 }
933 
934 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
935  MLIRContext *context, std::optional<Location> location,
936  ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
937  Builder b(context);
938  const Properties prop = adaptor.getProperties();
939  inferredReturnTypes.assign({RankedTensorType::get(
940  {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
941  return success();
942 }
943 
944 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
945  TypeRange r) {
946  if (l.size() != 1 || r.size() != 1)
947  return false;
948 
949  Type lhs = l.front();
950  Type rhs = r.front();
951 
952  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
953  // Shape type is compatible with all other valid return types.
954  return true;
955  return lhs == rhs;
956 }
957 
958 //===----------------------------------------------------------------------===//
959 // CstrBroadcastableOp
960 //===----------------------------------------------------------------------===//
961 
962 void CstrBroadcastableOp::getCanonicalizationPatterns(
964  // Canonicalization patterns have overlap with the considerations during
965  // folding in case additional shape information is inferred at some point that
966  // does not result in folding.
967  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
968  CstrBroadcastableEqOps,
969  RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
970  RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
971 }
972 
973 // Return true if there is exactly one attribute not representing a scalar
974 // broadcast.
976  bool nonScalarSeen = false;
977  for (Attribute a : attributes) {
978  if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
979  if (nonScalarSeen)
980  return false;
981  nonScalarSeen = true;
982  }
983  }
984  return true;
985 }
986 
987 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
988  // No broadcasting is needed if all operands but one are scalar.
989  if (hasAtMostSingleNonScalar(adaptor.getShapes()))
990  return BoolAttr::get(getContext(), true);
991 
992  if ([&] {
994  for (const auto &operand : adaptor.getShapes()) {
995  if (!operand)
996  return false;
997  extents.push_back(llvm::to_vector<6>(
998  llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
999  }
1001  }())
1002  return BoolAttr::get(getContext(), true);
1003 
1004  // Lastly, see if folding can be completed based on what constraints are known
1005  // on the input shapes.
1006  if ([&] {
1008  for (auto shapeValue : getShapes()) {
1009  extents.emplace_back();
1010  if (failed(getShapeVec(shapeValue, extents.back())))
1011  return false;
1012  }
1014  }())
1015  return BoolAttr::get(getContext(), true);
1016 
1017  // Because a failing witness result here represents an eventual assertion
1018  // failure, we do not replace it with a constant witness.
1019  return nullptr;
1020 }
1021 
1022 LogicalResult CstrBroadcastableOp::verify() {
1023  // Ensure that CstrBroadcastableOp contains at least two operands
1024  if (getNumOperands() < 2)
1025  return emitOpError("required at least 2 input shapes");
1026  return success();
1027 }
1028 
1029 //===----------------------------------------------------------------------===//
1030 // CstrEqOp
1031 //===----------------------------------------------------------------------===//
1032 
1033 void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1034  MLIRContext *context) {
1035  // If inputs are equal, return passing witness
1036  patterns.add<CstrEqEqOps>(context);
1037 }
1038 
1039 OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1040  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1041  return a && a == adaptor.getShapes().front();
1042  }))
1043  return BoolAttr::get(getContext(), true);
1044 
1045  // Because a failing witness result here represents an eventual assertion
1046  // failure, we do not try to replace it with a constant witness. Similarly, we
1047  // cannot if there are any non-const inputs.
1048  return nullptr;
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // ConstSizeOp
1053 //===----------------------------------------------------------------------===//
1054 
1055 void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1056  int64_t value) {
1057  build(builder, result, builder.getIndexAttr(value));
1058 }
1059 
1060 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1061 
1062 void ConstSizeOp::getAsmResultNames(
1063  llvm::function_ref<void(Value, StringRef)> setNameFn) {
1064  SmallString<4> buffer;
1065  llvm::raw_svector_ostream os(buffer);
1066  os << "c" << getValue();
1067  setNameFn(getResult(), os.str());
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // ConstWitnessOp
1072 //===----------------------------------------------------------------------===//
1073 
1074 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // CstrRequireOp
1078 //===----------------------------------------------------------------------===//
1079 
1080 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1081  return adaptor.getPred();
1082 }
1083 
1084 //===----------------------------------------------------------------------===//
1085 // DimOp
1086 //===----------------------------------------------------------------------===//
1087 
1088 std::optional<int64_t> DimOp::getConstantIndex() {
1089  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1090  return constSizeOp.getValue().getLimitedValue();
1091  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1092  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1093  return std::nullopt;
1094 }
1095 
1096 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1097  Type valType = getValue().getType();
1098  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1099  if (!valShapedType || !valShapedType.hasRank())
1100  return nullptr;
1101  std::optional<int64_t> index = getConstantIndex();
1102  if (!index.has_value())
1103  return nullptr;
1104  if (index.value() < 0 || index.value() >= valShapedType.getRank())
1105  return nullptr;
1106  auto extent = valShapedType.getDimSize(*index);
1107  if (ShapedType::isDynamic(extent))
1108  return nullptr;
1109  return IntegerAttr::get(IndexType::get(getContext()), extent);
1110 }
1111 
1112 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1113  MLIRContext *context, std::optional<Location> location,
1114  DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1115  inferredReturnTypes.assign({adaptor.getIndex().getType()});
1116  return success();
1117 }
1118 
1119 bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1120  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1121 }
1122 
1123 //===----------------------------------------------------------------------===//
1124 // DivOp
1125 //===----------------------------------------------------------------------===//
1126 
1127 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1128  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1129  if (!lhs)
1130  return nullptr;
1131  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1132  if (!rhs || rhs.getValue().isZero())
1133  return nullptr;
1134 
1135  // Division in APInt does not follow floor(lhs, rhs) when the result is
1136  // negative. Rather, APInt rounds toward zero.
1137  APInt quotient, remainder;
1138  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1139  if (quotient.isNegative() && !remainder.isZero()) {
1140  quotient -= 1;
1141  }
1142 
1143  Type indexTy = IndexType::get(getContext());
1144  return IntegerAttr::get(indexTy, quotient);
1145 }
1146 
1147 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1148  MLIRContext *context, std::optional<Location> location,
1149  DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1150  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1151  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1152  inferredReturnTypes.assign({SizeType::get(context)});
1153  else
1154  inferredReturnTypes.assign({IndexType::get(context)});
1155  return success();
1156 }
1157 
1158 bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1159  // SizeType is compatible with IndexType.
1160  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1161 }
1162 
1163 LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // ShapeEqOp
1167 //===----------------------------------------------------------------------===//
1168 
1169 OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1170  bool allSame = true;
1171  if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1172  return {};
1173  for (Attribute operand : adaptor.getShapes().drop_front()) {
1174  if (!operand)
1175  return {};
1176  allSame = allSame && operand == adaptor.getShapes().front();
1177  }
1178  return BoolAttr::get(getContext(), allSame);
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // IndexToSizeOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1186  // Constant values of both types, `shape.size` and `index`, are represented as
1187  // `IntegerAttr`s which makes constant folding simple.
1188  if (Attribute arg = adaptor.getArg())
1189  return arg;
1190  return {};
1191 }
1192 
1193 void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1194  MLIRContext *context) {
1195  patterns.add<SizeToIndexToSizeCanonicalization>(context);
1196 }
1197 
1198 //===----------------------------------------------------------------------===//
1199 // FromExtentsOp
1200 //===----------------------------------------------------------------------===//
1201 
1202 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1203  if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1204  return nullptr;
1205  SmallVector<int64_t, 6> extents;
1206  for (auto attr : adaptor.getExtents())
1207  extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1208  Builder builder(getContext());
1209  return builder.getIndexTensorAttr(extents);
1210 }
1211 
1212 //===----------------------------------------------------------------------===//
1213 // FunctionLibraryOp
1214 //===----------------------------------------------------------------------===//
1215 
1216 void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1217  StringRef name) {
1218  result.attributes.push_back(builder.getNamedAttr(
1220 }
1221 
1222 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1223  auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1224  getMapping().get(op->getName().getIdentifier()));
1225  if (!attr)
1226  return nullptr;
1227  return lookupSymbol<FuncOp>(attr);
1228 }
1229 
1230 ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1231  OperationState &result) {
1232  // Parse the op name.
1233  StringAttr nameAttr;
1234  if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
1235  result.attributes))
1236  return failure();
1237 
1238  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1239  return failure();
1240 
1241  auto *bodyRegion = result.addRegion();
1242  if (parser.parseRegion(*bodyRegion))
1243  return failure();
1244 
1245  if (parser.parseKeyword("mapping"))
1246  return failure();
1247 
1248  DictionaryAttr mappingAttr;
1249  if (parser.parseAttribute(mappingAttr,
1250  parser.getBuilder().getType<NoneType>(), "mapping",
1251  result.attributes))
1252  return failure();
1253  return success();
1254 }
1255 
1257  p << ' ';
1258  p.printSymbolName(getName());
1260  (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1261  p << ' ';
1262  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1263  /*printBlockTerminators=*/false);
1264  p << " mapping ";
1265  p.printAttributeWithoutType(getMappingAttr());
1266 }
1267 
1268 //===----------------------------------------------------------------------===//
1269 // FuncOp
1270 //===----------------------------------------------------------------------===//
1271 
1272 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273  ArrayRef<NamedAttribute> attrs) {
1274  OpBuilder builder(location->getContext());
1275  OperationState state(location, getOperationName());
1276  FuncOp::build(builder, state, name, type, attrs);
1277  return cast<FuncOp>(Operation::create(state));
1278 }
1279 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1281  SmallVector<NamedAttribute, 8> attrRef(attrs);
1282  return create(location, name, type, llvm::ArrayRef(attrRef));
1283 }
1284 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1286  ArrayRef<DictionaryAttr> argAttrs) {
1287  FuncOp func = create(location, name, type, attrs);
1288  func.setAllArgAttrs(argAttrs);
1289  return func;
1290 }
1291 
1292 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1293  FunctionType type, ArrayRef<NamedAttribute> attrs,
1294  ArrayRef<DictionaryAttr> argAttrs) {
1295  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1296  builder.getStringAttr(name));
1297  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1298  TypeAttr::get(type));
1299  state.attributes.append(attrs.begin(), attrs.end());
1300  state.addRegion();
1301 
1302  if (argAttrs.empty())
1303  return;
1304  assert(type.getNumInputs() == argAttrs.size());
1306  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
1307  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1308 }
1309 
1310 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1311  auto buildFuncType =
1312  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1314  std::string &) { return builder.getFunctionType(argTypes, results); };
1315 
1317  parser, result, /*allowVariadic=*/false,
1318  getFunctionTypeAttrName(result.name), buildFuncType,
1319  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1320 }
1321 
1322 void FuncOp::print(OpAsmPrinter &p) {
1324  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1325  getArgAttrsAttrName(), getResAttrsAttrName());
1326 }
1327 
1328 //===----------------------------------------------------------------------===//
1329 // GetExtentOp
1330 //===----------------------------------------------------------------------===//
1331 
1332 std::optional<int64_t> GetExtentOp::getConstantDim() {
1333  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1334  return constSizeOp.getValue().getLimitedValue();
1335  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1336  return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1337  return std::nullopt;
1338 }
1339 
1340 OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1341  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1342  if (!elements)
1343  return nullptr;
1344  std::optional<int64_t> dim = getConstantDim();
1345  if (!dim.has_value())
1346  return nullptr;
1347  if (dim.value() >= elements.getNumElements())
1348  return nullptr;
1349  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1350 }
1351 
1352 void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1353  int64_t dim) {
1354  auto loc = result.location;
1355  auto dimAttr = builder.getIndexAttr(dim);
1356  if (llvm::isa<ShapeType>(shape.getType())) {
1357  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
1358  build(builder, result, builder.getType<SizeType>(), shape, dim);
1359  } else {
1360  Value dim =
1361  builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr);
1362  build(builder, result, builder.getIndexType(), shape, dim);
1363  }
1364 }
1365 
1366 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1367  MLIRContext *context, std::optional<Location> location,
1368  GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1369  inferredReturnTypes.assign({IndexType::get(context)});
1370  return success();
1371 }
1372 
1373 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1374  TypeRange r) {
1375  // SizeType is compatible with IndexType.
1376  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1377 }
1378 
1379 LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1380 
1381 //===----------------------------------------------------------------------===//
1382 // IsBroadcastableOp
1383 //===----------------------------------------------------------------------===//
1384 
1385 void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1386  MLIRContext *context) {
1387  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1388 }
1389 
1390 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1391  // Can always broadcast fewer than two shapes.
1392  if (adaptor.getShapes().size() < 2) {
1393  return BoolAttr::get(getContext(), true);
1394  }
1395 
1396  return nullptr;
1397 }
1398 
1399 //===----------------------------------------------------------------------===//
1400 // MeetOp
1401 //===----------------------------------------------------------------------===//
1402 
1403 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1404  MLIRContext *context, std::optional<Location> location,
1405  MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1406  if (adaptor.getOperands().empty())
1407  return failure();
1408 
1409  auto isShapeType = [](Type arg) {
1410  if (llvm::isa<ShapeType>(arg))
1411  return true;
1412  return isExtentTensorType(arg);
1413  };
1414 
1415  ValueRange::type_range types = adaptor.getOperands().getTypes();
1416  Type acc = types.front();
1417  for (auto t : drop_begin(types)) {
1418  Type l = acc, r = t;
1419  if (!llvm::isa<ShapeType, SizeType>(l))
1420  std::swap(l, r);
1421 
1422  // Handle sizes, propagate error type if present.
1423  if (llvm::isa<SizeType>(l)) {
1424  if (llvm::isa<SizeType, IndexType>(r))
1425  acc = l;
1426  else
1427  return emitOptionalError(location, "requires all sizes or shapes");
1428  } else if (llvm::isa<IndexType>(l)) {
1429  if (llvm::isa<IndexType>(r))
1430  acc = r;
1431  else
1432  return emitOptionalError(location, "requires all sizes or shapes");
1433  } else if (llvm::isa<ShapeType>(l)) {
1434  // Handle shapes, propagate error type if present.
1435  if (isShapeType(r))
1436  acc = l;
1437  else
1438  return emitOptionalError(location, "requires all sizes or shapes");
1439  } else if (isExtentTensorType(l)) {
1440  auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1441  auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1442  if (ShapedType::isDynamic(rank1))
1443  acc = l;
1444  else if (ShapedType::isDynamic(rank2))
1445  acc = r;
1446  else if (rank1 != rank2)
1447  return emitOptionalError(location, "unequal shape cardinality");
1448  else
1449  acc = l;
1450  }
1451  }
1452  inferredReturnTypes.assign({acc});
1453  return success();
1454 }
1455 
1456 bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1457  if (l.size() != 1 || r.size() != 1)
1458  return false;
1459  if (l == r)
1460  return true;
1461 
1462  Type lhs = l.front();
1463  Type rhs = r.front();
1464 
1465  if (!llvm::isa<ShapeType, SizeType>(lhs))
1466  std::swap(lhs, rhs);
1467 
1468  if (llvm::isa<SizeType>(lhs))
1469  return llvm::isa<SizeType, IndexType>(rhs);
1470  if (llvm::isa<ShapeType>(lhs))
1471  return llvm::isa<ShapeType, TensorType>(rhs);
1472 
1473  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1474  return true;
1475  return false;
1476 }
1477 
1478 //===----------------------------------------------------------------------===//
1479 // RankOp
1480 //===----------------------------------------------------------------------===//
1481 
1482 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1483  auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1484  if (!shape)
1485  return {};
1486  int64_t rank = shape.getNumElements();
1487  Builder builder(getContext());
1488  return builder.getIndexAttr(rank);
1489 }
1490 
1491 /// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1492 /// Constant folding fails in cases where only the rank is constant, not the
1493 /// shape itself.
1494 /// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1495 ///
1496 /// Example:
1497 ///
1498 /// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1499 /// %rank = shape.rank %shape
1500 ///
1501 /// becomes
1502 ///
1503 /// %rank = shape.const_size 3
1504 
1505 namespace {
1506 struct RankShapeOfCanonicalizationPattern
1507  : public OpRewritePattern<shape::RankOp> {
1509 
1510  LogicalResult matchAndRewrite(shape::RankOp op,
1511  PatternRewriter &rewriter) const override {
1512  auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1513  if (!shapeOfOp)
1514  return failure();
1515  auto rankedTensorType =
1516  llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1517  if (!rankedTensorType)
1518  return failure();
1519  int64_t rank = rankedTensorType.getRank();
1520  if (llvm::isa<IndexType>(op.getType())) {
1521  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1522  rank);
1523  } else if (llvm::isa<shape::SizeType>(op.getType())) {
1524  rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1525  } else {
1526  return failure();
1527  }
1528  return success();
1529  }
1530 };
1531 } // namespace
1532 
1533 void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1534  MLIRContext *context) {
1535  patterns.add<RankShapeOfCanonicalizationPattern>(context);
1536 }
1537 
1538 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1539  MLIRContext *context, std::optional<Location> location,
1540  RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1541  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1542  inferredReturnTypes.assign({SizeType::get(context)});
1543  else
1544  inferredReturnTypes.assign({IndexType::get(context)});
1545  return success();
1546 }
1547 
1548 bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1549  // SizeType is compatible with IndexType.
1550  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1551 }
1552 
1553 LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // NumElementsOp
1557 //===----------------------------------------------------------------------===//
1558 
1559 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1560 
1561  // Fold only when argument constant.
1562  Attribute shape = adaptor.getShape();
1563  if (!shape)
1564  return {};
1565 
1566  APInt product(64, 1);
1567  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1568  product *= value;
1569  Builder builder(getContext());
1570  return builder.getIndexAttr(product.getLimitedValue());
1571 }
1572 
1573 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1574  MLIRContext *context, std::optional<Location> location,
1575  NumElementsOp::Adaptor adaptor,
1576  SmallVectorImpl<Type> &inferredReturnTypes) {
1577  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1578  inferredReturnTypes.assign({SizeType::get(context)});
1579  else
1580  inferredReturnTypes.assign({IndexType::get(context)});
1581  return success();
1582 }
1583 
1584 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1585  TypeRange r) {
1586  // SizeType is compatible with IndexType.
1587  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1588 }
1589 
1590 LogicalResult shape::NumElementsOp::verify() {
1591  return verifySizeOrIndexOp(*this);
1592 }
1593 
1594 //===----------------------------------------------------------------------===//
1595 // MaxOp
1596 //===----------------------------------------------------------------------===//
1597 
1598 OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1599  // If operands are equal, just propagate one.
1600  if (getLhs() == getRhs())
1601  return getLhs();
1602  return nullptr;
1603 }
1604 
1605 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1606  MLIRContext *context, std::optional<Location> location,
1607  MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1608  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1609  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1610  else
1611  inferredReturnTypes.assign({SizeType::get(context)});
1612  return success();
1613 }
1614 
1615 bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1616  if (l.size() != 1 || r.size() != 1)
1617  return false;
1618  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1619  return true;
1620  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1621  return true;
1622  return false;
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // MinOp
1627 //===----------------------------------------------------------------------===//
1628 
1629 OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1630  // If operands are equal, just propagate one.
1631  if (getLhs() == getRhs())
1632  return getLhs();
1633  return nullptr;
1634 }
1635 
1636 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1637  MLIRContext *context, std::optional<Location> location,
1638  MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1639  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1640  inferredReturnTypes.assign({adaptor.getLhs().getType()});
1641  else
1642  inferredReturnTypes.assign({SizeType::get(context)});
1643  return success();
1644 }
1645 
1646 bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1647  if (l.size() != 1 || r.size() != 1)
1648  return false;
1649  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1650  return true;
1651  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1652  return true;
1653  return false;
1654 }
1655 
1656 //===----------------------------------------------------------------------===//
1657 // MulOp
1658 //===----------------------------------------------------------------------===//
1659 
1660 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1661  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1662  if (!lhs)
1663  return nullptr;
1664  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1665  if (!rhs)
1666  return nullptr;
1667  APInt folded = lhs.getValue() * rhs.getValue();
1668  Type indexTy = IndexType::get(getContext());
1669  return IntegerAttr::get(indexTy, folded);
1670 }
1671 
1672 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1673  MLIRContext *context, std::optional<Location> location,
1674  MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1675  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1676  llvm::isa<SizeType>(adaptor.getRhs().getType()))
1677  inferredReturnTypes.assign({SizeType::get(context)});
1678  else
1679  inferredReturnTypes.assign({IndexType::get(context)});
1680  return success();
1681 }
1682 
1683 bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1684  // SizeType is compatible with IndexType.
1685  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1686 }
1687 
1688 LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1689 
1690 //===----------------------------------------------------------------------===//
1691 // ShapeOfOp
1692 //===----------------------------------------------------------------------===//
1693 
1694 namespace {
1695 /// Replace shape_of(x) where x has a constant shape with a const_shape op.
1696 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1698 
1699  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1700  PatternRewriter &rewriter) const override {
1701  auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1702  if (!type || !type.hasStaticShape())
1703  return failure();
1704  Location loc = op.getLoc();
1705  Value constShape =
1706  rewriter
1707  .create<ConstShapeOp>(loc,
1708  rewriter.getIndexTensorAttr(type.getShape()))
1709  .getResult();
1710  if (constShape.getType() != op.getResult().getType())
1711  constShape = rewriter.create<tensor::CastOp>(
1712  loc, op.getResult().getType(), constShape);
1713  rewriter.replaceOp(op, constShape);
1714  return success();
1715  }
1716 };
1717 
1718 // Canonicalize
1719 //
1720 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1721 // %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1722 //
1723 // to
1724 //
1725 // %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1726 // %1 = %shape
1727 //
1728 struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1730 
1731  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1732  PatternRewriter &rewriter) const override {
1733  auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1734  if (!tensorReshapeOp)
1735  return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1736  if (!isa<TensorType>(op.getType()))
1737  return rewriter.notifyMatchFailure(op, "result is not a tensor");
1738 
1739  // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1740  // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1741  // formed IR, it may not be identical (dynamically vs statically shaped),
1742  // in which case it needs to be cast first using 'tensor.cast'.
1743  // Additionally, it may not have identical element type (i32 vs index)
1744  // while it has identical shaped type (dynamic vs static), in which case it
1745  // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1746  // op result must be shape or extent tensor.
1747  Value shape = tensorReshapeOp.getShape();
1748 
1749  auto opTensorTy = cast<RankedTensorType>(op.getType());
1750  auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1751 
1752  if (opTensorTy != shapeTensorTy) {
1753  if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1754  shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1755  else if (!isExtentTensorType(shapeTensorTy))
1756  shape =
1757  rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
1758  }
1759 
1760  rewriter.replaceOp(op, shape);
1761  return success();
1762  }
1763 };
1764 
1765 // Canonicalize
1766 // ```
1767 // %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1768 // %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1769 // ```
1770 // to
1771 // ```
1772 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1773 // ```
1774 struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1776 
1777  LogicalResult matchAndRewrite(tensor::CastOp op,
1778  PatternRewriter &rewriter) const override {
1779  auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1780  if (!ty || ty.getRank() != 1)
1781  return failure();
1782 
1783  auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1784  if (!shapeOfOp)
1785  return failure();
1786 
1787  // Argument type must be ranked and must not conflict.
1788  auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1789  if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1790  return failure();
1791 
1792  rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1793  return success();
1794  }
1795 };
1796 } // namespace
1797 
1798 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1799  MLIRContext *context) {
1800  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1801  ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1802  context);
1803 }
1804 
1805 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1806  MLIRContext *context, std::optional<Location> location,
1807  ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1808  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1809  inferredReturnTypes.assign({ShapeType::get(context)});
1810  else {
1811  auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1812  int64_t rank =
1813  shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1814  Type indexTy = IndexType::get(context);
1815  Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1816  inferredReturnTypes.assign({extentTensorTy});
1817  }
1818  return success();
1819 }
1820 
1821 bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1822  if (l.size() != 1 || r.size() != 1)
1823  return false;
1824  if (l == r)
1825  return true;
1826 
1827  Type lhs = l.front();
1828  Type rhs = r.front();
1829 
1830  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1831  !llvm::isa<ShapeType, ShapedType>(rhs))
1832  return false;
1833 
1834  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1835  // Shape type is compatible with all other valid return types.
1836  return true;
1837 
1838  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1839  return true;
1840  return false;
1841 }
1842 
1843 LogicalResult shape::ShapeOfOp::verify() {
1844  return verifyShapeOrExtentTensorOp(*this);
1845 }
1846 
1847 //===----------------------------------------------------------------------===//
1848 // SizeToIndexOp
1849 //===----------------------------------------------------------------------===//
1850 
1851 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1852  // Constant values of both types, `shape.size` and `index`, are represented as
1853  // `IntegerAttr`s which makes constant folding simple.
1854  if (Attribute arg = adaptor.getArg())
1855  return arg;
1856  return OpFoldResult();
1857 }
1858 
1859 void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1860  MLIRContext *context) {
1861  patterns.add<IndexToSizeToIndexCanonicalization>(context);
1862 }
1863 
1864 bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1865  if (inputs.size() != 1 || outputs.size() != 1)
1866  return false;
1867  return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1868  llvm::isa<IndexType>(outputs[0]);
1869 }
1870 
1871 //===----------------------------------------------------------------------===//
1872 // YieldOp
1873 //===----------------------------------------------------------------------===//
1874 
1875 LogicalResult shape::YieldOp::verify() {
1876  auto *parentOp = (*this)->getParentOp();
1877  auto results = parentOp->getResults();
1878  auto operands = getOperands();
1879 
1880  if (parentOp->getNumResults() != getNumOperands())
1881  return emitOpError() << "number of operands does not match number of "
1882  "results of its parent";
1883  for (auto e : llvm::zip(results, operands))
1884  if (std::get<0>(e).getType() != std::get<1>(e).getType())
1885  return emitOpError() << "types mismatch between yield op and its parent";
1886 
1887  return success();
1888 }
1889 
1890 //===----------------------------------------------------------------------===//
1891 // SplitAtOp
1892 //===----------------------------------------------------------------------===//
1893 
1894 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1895  SmallVectorImpl<OpFoldResult> &results) {
1896  if (!adaptor.getOperand() || !adaptor.getIndex())
1897  return failure();
1898  auto shapeVec = llvm::to_vector<6>(
1899  llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1900  auto shape = llvm::ArrayRef(shapeVec);
1901  auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1902  // Verify that the split point is in the correct range.
1903  // TODO: Constant fold to an "error".
1904  int64_t rank = shape.size();
1905  if (-rank > splitPoint || splitPoint > rank)
1906  return failure();
1907  if (splitPoint < 0)
1908  splitPoint += shape.size();
1909  Builder builder(adaptor.getOperand().getContext());
1910  results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1911  results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1912  return success();
1913 }
1914 
1915 //===----------------------------------------------------------------------===//
1916 // ToExtentTensorOp
1917 //===----------------------------------------------------------------------===//
1918 
1919 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1920  if (!adaptor.getInput())
1921  return OpFoldResult();
1922  Builder builder(getContext());
1923  auto shape = llvm::to_vector<6>(
1924  llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1925  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1926  builder.getIndexType());
1927  return DenseIntElementsAttr::get(type, shape);
1928 }
1929 
1930 bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1931  if (inputs.size() != 1 || outputs.size() != 1)
1932  return false;
1933  if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1934  if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1935  inputTensor.getRank() != 1)
1936  return false;
1937  } else if (!llvm::isa<ShapeType>(inputs[0])) {
1938  return false;
1939  }
1940 
1941  TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1942  return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1943 }
1944 
1945 //===----------------------------------------------------------------------===//
1946 // ReduceOp
1947 //===----------------------------------------------------------------------===//
1948 
1949 void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1950  ValueRange initVals) {
1951  OpBuilder::InsertionGuard g(builder);
1952  result.addOperands(shape);
1953  result.addOperands(initVals);
1954 
1955  Region *bodyRegion = result.addRegion();
1956  Block *bodyBlock = builder.createBlock(
1957  bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1958 
1959  Type elementType;
1960  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1961  elementType = tensorType.getElementType();
1962  else
1963  elementType = SizeType::get(builder.getContext());
1964  bodyBlock->addArgument(elementType, shape.getLoc());
1965 
1966  for (Value initVal : initVals) {
1967  bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1968  result.addTypes(initVal.getType());
1969  }
1970 }
1971 
1972 LogicalResult ReduceOp::verify() {
1973  // Verify block arg types.
1974  Block &block = getRegion().front();
1975 
1976  // The block takes index, extent, and aggregated values as arguments.
1977  auto blockArgsCount = getInitVals().size() + 2;
1978  if (block.getNumArguments() != blockArgsCount)
1979  return emitOpError() << "ReduceOp body is expected to have "
1980  << blockArgsCount << " arguments";
1981 
1982  // The first block argument is the index and must always be of type `index`.
1983  if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1984  return emitOpError(
1985  "argument 0 of ReduceOp body is expected to be of IndexType");
1986 
1987  // The second block argument is the extent and must be of type `size` or
1988  // `index`, depending on whether the reduce operation is applied to a shape or
1989  // to an extent tensor.
1990  Type extentTy = block.getArgument(1).getType();
1991  if (llvm::isa<ShapeType>(getShape().getType())) {
1992  if (!llvm::isa<SizeType>(extentTy))
1993  return emitOpError("argument 1 of ReduceOp body is expected to be of "
1994  "SizeType if the ReduceOp operates on a ShapeType");
1995  } else {
1996  if (!llvm::isa<IndexType>(extentTy))
1997  return emitOpError(
1998  "argument 1 of ReduceOp body is expected to be of IndexType if the "
1999  "ReduceOp operates on an extent tensor");
2000  }
2001 
2002  for (const auto &type : llvm::enumerate(getInitVals()))
2003  if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2004  return emitOpError() << "type mismatch between argument "
2005  << type.index() + 2
2006  << " of ReduceOp body and initial value "
2007  << type.index();
2008  return success();
2009 }
2010 
2011 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2012  // Parse operands.
2014  Type shapeOrExtentTensorType;
2015  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2017  parser.parseColonType(shapeOrExtentTensorType) ||
2018  parser.parseOptionalArrowTypeList(result.types))
2019  return failure();
2020 
2021  // Resolve operands.
2022  auto initVals = llvm::ArrayRef(operands).drop_front();
2023  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2024  result.operands) ||
2025  parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2026  result.operands))
2027  return failure();
2028 
2029  // Parse the body.
2030  Region *body = result.addRegion();
2031  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2032  return failure();
2033 
2034  // Parse attributes.
2035  if (parser.parseOptionalAttrDict(result.attributes))
2036  return failure();
2037 
2038  return success();
2039 }
2040 
2041 void ReduceOp::print(OpAsmPrinter &p) {
2042  p << '(' << getShape() << ", " << getInitVals()
2043  << ") : " << getShape().getType();
2044  p.printOptionalArrowTypeList(getResultTypes());
2045  p << ' ';
2046  p.printRegion(getRegion());
2047  p.printOptionalAttrDict((*this)->getAttrs());
2048 }
2049 
2050 #define GET_OP_CLASSES
2051 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2052 
2053 #define GET_TYPEDEF_CLASSES
2054 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition: Shape.cpp:67
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition: Shape.cpp:975
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition: Shape.cpp:84
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition: Shape.cpp:97
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition: Shape.cpp:72
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
Operation & front()
Definition: Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:189
MLIRContext * getContext() const
Definition: Builders.h:56
IndexType getIndexType()
Definition: Builders.cpp:51
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:90
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:442
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
ValueTypeRange< ValueRange > type_range
Definition: ValueRange.h:418
Type front()
Return first type in the range.
Definition: TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition: Traits.cpp:25
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition: Barvinok.cpp:63
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
bool isExtentTensorType(Type)
Definition: Shape.cpp:45
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition: Shape.cpp:50
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:41
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.