MLIR  20.0.0git
Operator.cpp
Go to the documentation of this file.
1 //===- Operator.cpp - Operator class --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Operator wrapper to simplify using TableGen Record defining a MLIR Op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/TableGen/Operator.h"
14 #include "mlir/TableGen/Argument.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "mlir/TableGen/Type.h"
18 #include "llvm/ADT/EquivalenceClasses.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
29 #include <list>
30 
31 #define DEBUG_TYPE "mlir-tblgen-operator"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::DagInit;
37 using llvm::DefInit;
38 using llvm::Init;
39 using llvm::ListInit;
40 using llvm::Record;
41 using llvm::StringInit;
42 
43 Operator::Operator(const Record &def)
44  : dialect(def.getValueAsDef("opDialect")), def(def) {
45  // The first `_` in the op's TableGen def name is treated as separating the
46  // dialect prefix and the op class name. The dialect prefix will be ignored if
47  // not empty. Otherwise, if def name starts with a `_`, the `_` is considered
48  // as part of the class name.
49  StringRef prefix;
50  std::tie(prefix, cppClassName) = def.getName().split('_');
51  if (prefix.empty()) {
52  // Class name with a leading underscore and without dialect prefix
53  cppClassName = def.getName();
54  } else if (cppClassName.empty()) {
55  // Class name without dialect prefix
56  cppClassName = prefix;
57  }
58 
59  cppNamespace = def.getValueAsString("cppNamespace");
60 
61  populateOpStructure();
62  assertInvariants();
63 }
64 
65 std::string Operator::getOperationName() const {
66  auto prefix = dialect.getName();
67  auto opName = def.getValueAsString("opName");
68  if (prefix.empty())
69  return std::string(opName);
70  return std::string(llvm::formatv("{0}.{1}", prefix, opName));
71 }
72 
73 std::string Operator::getAdaptorName() const {
74  return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
75 }
76 
77 std::string Operator::getGenericAdaptorName() const {
78  return std::string(llvm::formatv("{0}GenericAdaptor", getCppClassName()));
79 }
80 
81 /// Assert the invariants of accessors generated for the given name.
82 static void assertAccessorInvariants(const Operator &op, StringRef name) {
83  std::string accessorName =
84  convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
85 
86  // Functor used to detect when an accessor will cause an overlap with an
87  // operation API.
88  //
89  // There are a little bit more invasive checks possible for cases where not
90  // all ops have the trait that would cause overlap. For many cases here,
91  // renaming would be better (e.g., we can only guard in limited manner
92  // against methods from traits and interfaces here, so avoiding these in op
93  // definition is safer).
94  auto nameOverlapsWithOpAPI = [&](StringRef newName) {
95  if (newName == "AttributeNames" || newName == "Attributes" ||
96  newName == "Operation")
97  return true;
98  if (newName == "Operands")
99  return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
100  if (newName == "Regions")
101  return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
102  if (newName == "Type")
103  return op.getNumResults() != 1;
104  return false;
105  };
106  if (nameOverlapsWithOpAPI(accessorName)) {
107  // This error could be avoided in situations where the final function is
108  // identical, but preferably the op definition should avoid using generic
109  // names.
110  PrintFatalError(op.getLoc(), "generated accessor for `" + name +
111  "` overlaps with a default one; please "
112  "rename to avoid overlap");
113  }
114 }
115 
117  // Check that the name of arguments/results/regions/successors don't overlap.
118  DenseMap<StringRef, StringRef> existingNames;
119  auto checkName = [&](StringRef name, StringRef entity) {
120  if (name.empty())
121  return;
122  auto insertion = existingNames.insert({name, entity});
123  if (insertion.second) {
124  // Assert invariants for accessors generated for this name.
125  assertAccessorInvariants(*this, name);
126  return;
127  }
128  if (entity == insertion.first->second)
129  PrintFatalError(getLoc(), "op has a conflict with two " + entity +
130  " having the same name '" + name + "'");
131  PrintFatalError(getLoc(), "op has a conflict with " +
132  insertion.first->second + " and " + entity +
133  " both having an entry with the name '" +
134  name + "'");
135  };
136  // Check operands amongst themselves.
137  for (int i : llvm::seq<int>(0, getNumOperands()))
138  checkName(getOperand(i).name, "operands");
139 
140  // Check results amongst themselves and against operands.
141  for (int i : llvm::seq<int>(0, getNumResults()))
142  checkName(getResult(i).name, "results");
143 
144  // Check regions amongst themselves and against operands and results.
145  for (int i : llvm::seq<int>(0, getNumRegions()))
146  checkName(getRegion(i).name, "regions");
147 
148  // Check successors amongst themselves and against operands, results, and
149  // regions.
150  for (int i : llvm::seq<int>(0, getNumSuccessors()))
151  checkName(getSuccessor(i).name, "successors");
152 }
153 
154 StringRef Operator::getDialectName() const { return dialect.getName(); }
155 
156 StringRef Operator::getCppClassName() const { return cppClassName; }
157 
158 std::string Operator::getQualCppClassName() const {
159  if (cppNamespace.empty())
160  return std::string(cppClassName);
161  return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName));
162 }
163 
164 StringRef Operator::getCppNamespace() const { return cppNamespace; }
165 
167  const DagInit *results = def.getValueAsDag("results");
168  return results->getNumArgs();
169 }
170 
172  constexpr auto attr = "extraClassDeclaration";
173  if (def.isValueUnset(attr))
174  return {};
175  return def.getValueAsString(attr);
176 }
177 
179  constexpr auto attr = "extraClassDefinition";
180  if (def.isValueUnset(attr))
181  return {};
182  return def.getValueAsString(attr);
183 }
184 
185 const Record &Operator::getDef() const { return def; }
186 
188  return def.getValueAsBit("skipDefaultBuilders");
189 }
190 
192  return results.begin();
193 }
194 
196  return results.end();
197 }
198 
200  return {result_begin(), result_end()};
201 }
202 
204  const DagInit *results = def.getValueAsDag("results");
205  return TypeConstraint(cast<DefInit>(results->getArg(index)));
206 }
207 
208 StringRef Operator::getResultName(int index) const {
209  const DagInit *results = def.getValueAsDag("results");
210  return results->getArgNameStr(index);
211 }
212 
214  const Record *result =
215  cast<DefInit>(def.getValueAsDag("results")->getArg(index))->getDef();
216  if (!result->isSubClassOf("OpVariable"))
217  return var_decorator_range(nullptr, nullptr);
218  return *result->getValueAsListInit("decorators");
219 }
220 
222  return llvm::count_if(results, [](const NamedTypeConstraint &c) {
223  return c.constraint.isVariableLength();
224  });
225 }
226 
228  return llvm::count_if(operands, [](const NamedTypeConstraint &c) {
229  return c.constraint.isVariableLength();
230  });
231 }
232 
234  return getNumArgs() == 1 && getArg(0).is<NamedTypeConstraint *>() &&
235  getOperand(0).isVariadic();
236 }
237 
238 Operator::arg_iterator Operator::arg_begin() const { return arguments.begin(); }
239 
240 Operator::arg_iterator Operator::arg_end() const { return arguments.end(); }
241 
243  return {arg_begin(), arg_end()};
244 }
245 
246 StringRef Operator::getArgName(int index) const {
247  const DagInit *argumentValues = def.getValueAsDag("arguments");
248  return argumentValues->getArgNameStr(index);
249 }
250 
252  const Record *arg =
253  cast<DefInit>(def.getValueAsDag("arguments")->getArg(index))->getDef();
254  if (!arg->isSubClassOf("OpVariable"))
255  return var_decorator_range(nullptr, nullptr);
256  return *arg->getValueAsListInit("decorators");
257 }
258 
259 const Trait *Operator::getTrait(StringRef trait) const {
260  for (const auto &t : traits) {
261  if (const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
262  if (traitDef->getFullyQualifiedTraitName() == trait)
263  return traitDef;
264  } else if (const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
265  if (traitDef->getFullyQualifiedTraitName() == trait)
266  return traitDef;
267  } else if (const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
268  if (traitDef->getFullyQualifiedTraitName() == trait)
269  return traitDef;
270  }
271  }
272  return nullptr;
273 }
274 
276  return regions.begin();
277 }
279  return regions.end();
280 }
283  return {region_begin(), region_end()};
284 }
285 
286 unsigned Operator::getNumRegions() const { return regions.size(); }
287 
288 const NamedRegion &Operator::getRegion(unsigned index) const {
289  return regions[index];
290 }
291 
293  return llvm::count_if(regions,
294  [](const NamedRegion &c) { return c.isVariadic(); });
295 }
296 
298  return successors.begin();
299 }
301  return successors.end();
302 }
305  return {successor_begin(), successor_end()};
306 }
307 
308 unsigned Operator::getNumSuccessors() const { return successors.size(); }
309 
310 const NamedSuccessor &Operator::getSuccessor(unsigned index) const {
311  return successors[index];
312 }
313 
315  return llvm::count_if(successors,
316  [](const NamedSuccessor &c) { return c.isVariadic(); });
317 }
318 
320  return traits.begin();
321 }
323  return traits.end();
324 }
326  return {trait_begin(), trait_end()};
327 }
328 
330  return attributes.begin();
331 }
333  return attributes.end();
334 }
337  return {attribute_begin(), attribute_end()};
338 }
340  return attributes.begin();
341 }
343  return attributes.end();
344 }
346  return {attribute_begin(), attribute_end()};
347 }
348 
350  return operands.begin();
351 }
353  return operands.end();
354 }
356  return {operand_begin(), operand_end()};
357 }
358 
359 auto Operator::getArg(int index) const -> Argument { return arguments[index]; }
360 
361 bool Operator::isVariadic() const {
362  return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
363  [](const NamedTypeConstraint &op) { return op.isVariadic(); });
364 }
365 
366 void Operator::populateTypeInferenceInfo(
367  const llvm::StringMap<int> &argumentsAndResultsIndex) {
368  // If the type inference op interface is not registered, then do not attempt
369  // to determine if the result types can be inferred.
370  auto &recordKeeper = def.getRecords();
371  auto *inferTrait = recordKeeper.getDef(inferTypeOpInterface);
372  allResultsHaveKnownTypes = false;
373  if (!inferTrait)
374  return;
375 
376  // If there are no results, skip this; otherwise the generated build method
377  // would overlap with another autogenerated builder.
378  if (getNumResults() == 0)
379  return;
380 
381  // Skip ops with variadic or optional results.
382  if (getNumVariableLengthResults() > 0)
383  return;
384 
385  // Skip cases currently being custom generated.
386  // TODO: Remove special cases.
387  if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
388  // Check for a non-variable length operand to use as the type anchor.
389  auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
390  NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
391  return operand && !operand->isVariableLength();
392  });
393  if (operandI == arguments.end())
394  return;
395 
396  // All result types are inferred from the operand type.
397  int operandIdx = operandI - arguments.begin();
398  for (int i = 0; i < getNumResults(); ++i)
399  resultTypeMapping.emplace_back(operandIdx, "$_self");
400 
401  allResultsHaveKnownTypes = true;
402  traits.push_back(Trait::create(inferTrait->getDefInit()));
403  return;
404  }
405 
406  /// This struct represents a node in this operation's result type inference
407  /// graph. Each node has a list of incoming type inference edges `sources`.
408  /// Each edge represents a "source" from which the result type can be
409  /// inferred, either an operand (leaf) or another result (node). When a node
410  /// is known to have a fully-inferred type, `inferred` is set to true.
411  struct ResultTypeInference {
412  /// The list of incoming type inference edges.
414  /// This flag is set to true when the result type is known to be inferrable.
415  bool inferred = false;
416  };
417 
418  // This vector represents the type inference graph, with one node for each
419  // operation result. The nth element is the node for the nth result.
421 
422  // For all results whose types are buildable, initialize their type inference
423  // nodes with an edge to themselves. Mark those nodes as fully-inferred.
424  for (auto [idx, infer] : llvm::enumerate(inference)) {
425  if (getResult(idx).constraint.getBuilderCall()) {
426  infer.sources.emplace_back(InferredResultType::mapResultIndex(idx),
427  "$_self");
428  infer.inferred = true;
429  }
430  }
431 
432  // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
433  // result type inference graph.
434  for (const Trait &trait : traits) {
435  const Record &def = trait.getDef();
436 
437  // If the infer type op interface was manually added, then treat it as
438  // intention that the op needs special handling.
439  // TODO: Reconsider whether to always generate, this is more conservative
440  // and keeps existing behavior so starting that way for now.
441  if (def.isSubClassOf(
442  llvm::formatv("{0}::Trait", inferTypeOpInterface).str()))
443  return;
444  if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
445  if (&traitDef->getDef() == inferTrait)
446  return;
447 
448  // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
449  // type transformer.
450  if (def.isSubClassOf("TypesMatchWith")) {
451  int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs"));
452  // Ignore operand type inference.
453  if (InferredResultType::isArgIndex(target))
454  continue;
455  int resultIndex = InferredResultType::unmapResultIndex(target);
456  ResultTypeInference &infer = inference[resultIndex];
457  // If the type of the result has already been inferred, do nothing.
458  if (infer.inferred)
459  continue;
460  int sourceIndex =
461  argumentsAndResultsIndex.lookup(def.getValueAsString("lhs"));
462  infer.sources.emplace_back(sourceIndex,
463  def.getValueAsString("transformer").str());
464  // Locally propagate inferredness.
465  infer.inferred =
466  InferredResultType::isArgIndex(sourceIndex) ||
467  inference[InferredResultType::unmapResultIndex(sourceIndex)].inferred;
468  continue;
469  }
470 
471  if (!def.isSubClassOf("AllTypesMatch"))
472  continue;
473 
474  auto values = def.getValueAsListOfStrings("values");
475  // The `AllTypesMatch` trait represents an N <-> N fanin and fanout. That
476  // is, every result type has an edge from every other type. However, if any
477  // one of the values refers to an operand or a result with a fully-inferred
478  // type, we can infer all other types from that value. Try to find a
479  // fully-inferred type in the list.
480  std::optional<int> fullyInferredIndex;
481  SmallVector<int> resultIndices;
482  for (StringRef name : values) {
483  int index = argumentsAndResultsIndex.lookup(name);
485  resultIndices.push_back(InferredResultType::unmapResultIndex(index));
486  if (InferredResultType::isArgIndex(index) ||
487  inference[InferredResultType::unmapResultIndex(index)].inferred)
488  fullyInferredIndex = index;
489  }
490  if (fullyInferredIndex) {
491  // Make the fully-inferred type the only source for all results that
492  // aren't already inferred -- a 1 -> N fanout.
493  for (int resultIndex : resultIndices) {
494  ResultTypeInference &infer = inference[resultIndex];
495  if (!infer.inferred) {
496  infer.sources.assign(1, {*fullyInferredIndex, "$_self"});
497  infer.inferred = true;
498  }
499  }
500  } else {
501  // Add an edge between every result and every other type; N <-> N.
502  for (int resultIndex : resultIndices) {
503  for (int otherResultIndex : resultIndices) {
504  if (resultIndex == otherResultIndex)
505  continue;
506  inference[resultIndex].sources.emplace_back(otherResultIndex,
507  "$_self");
508  }
509  }
510  }
511  }
512 
513  // Propagate inferredness until a fixed point.
514  std::vector<ResultTypeInference *> worklist;
515  for (ResultTypeInference &infer : inference)
516  if (!infer.inferred)
517  worklist.push_back(&infer);
518  bool changed;
519  do {
520  changed = false;
521  for (auto cur = worklist.begin(); cur != worklist.end();) {
522  ResultTypeInference &infer = **cur;
523 
524  InferredResultType *iter =
525  llvm::find_if(infer.sources, [&](const InferredResultType &source) {
526  assert(InferredResultType::isResultIndex(source.getIndex()));
527  return inference[InferredResultType::unmapResultIndex(
528  source.getIndex())]
529  .inferred;
530  });
531  if (iter == infer.sources.end()) {
532  ++cur;
533  continue;
534  }
535 
536  changed = true;
537  infer.inferred = true;
538  // Make this the only source for the result. This breaks any cycles.
539  infer.sources.assign(1, *iter);
540  cur = worklist.erase(cur);
541  }
542  } while (changed);
543 
544  allResultsHaveKnownTypes = worklist.empty();
545 
546  // If the types could be computed, then add type inference trait.
547  if (allResultsHaveKnownTypes) {
548  traits.push_back(Trait::create(inferTrait->getDefInit()));
549  for (const ResultTypeInference &infer : inference)
550  resultTypeMapping.push_back(infer.sources.front());
551  }
552 }
553 
554 void Operator::populateOpStructure() {
555  auto &recordKeeper = def.getRecords();
556  auto *typeConstraintClass = recordKeeper.getClass("TypeConstraint");
557  auto *attrClass = recordKeeper.getClass("Attr");
558  auto *propertyClass = recordKeeper.getClass("Property");
559  auto *derivedAttrClass = recordKeeper.getClass("DerivedAttr");
560  auto *opVarClass = recordKeeper.getClass("OpVariable");
561  numNativeAttributes = 0;
562 
563  const DagInit *argumentValues = def.getValueAsDag("arguments");
564  unsigned numArgs = argumentValues->getNumArgs();
565 
566  // Mapping from name of to argument or result index. Arguments are indexed
567  // to match getArg index, while the results are negatively indexed.
568  llvm::StringMap<int> argumentsAndResultsIndex;
569 
570  // Handle operands and native attributes.
571  for (unsigned i = 0; i != numArgs; ++i) {
572  auto *arg = argumentValues->getArg(i);
573  auto givenName = argumentValues->getArgNameStr(i);
574  auto *argDefInit = dyn_cast<DefInit>(arg);
575  if (!argDefInit)
576  PrintFatalError(def.getLoc(),
577  Twine("undefined type for argument #") + Twine(i));
578  const Record *argDef = argDefInit->getDef();
579  if (argDef->isSubClassOf(opVarClass))
580  argDef = argDef->getValueAsDef("constraint");
581 
582  if (argDef->isSubClassOf(typeConstraintClass)) {
583  operands.push_back(
584  NamedTypeConstraint{givenName, TypeConstraint(argDef)});
585  } else if (argDef->isSubClassOf(attrClass)) {
586  if (givenName.empty())
587  PrintFatalError(argDef->getLoc(), "attributes must be named");
588  if (argDef->isSubClassOf(derivedAttrClass))
589  PrintFatalError(argDef->getLoc(),
590  "derived attributes not allowed in argument list");
591  attributes.push_back({givenName, Attribute(argDef)});
592  ++numNativeAttributes;
593  } else if (argDef->isSubClassOf(propertyClass)) {
594  if (givenName.empty())
595  PrintFatalError(argDef->getLoc(), "properties must be named");
596  properties.push_back({givenName, Property(argDef)});
597  } else {
598  PrintFatalError(def.getLoc(),
599  "unexpected def type; only defs deriving "
600  "from TypeConstraint or Attr or Property are allowed");
601  }
602  if (!givenName.empty())
603  argumentsAndResultsIndex[givenName] = i;
604  }
605 
606  // Handle derived attributes.
607  for (const auto &val : def.getValues()) {
608  if (auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
609  if (!record->isSubClassOf(attrClass))
610  continue;
611  if (!record->isSubClassOf(derivedAttrClass))
612  PrintFatalError(def.getLoc(),
613  "unexpected Attr where only DerivedAttr is allowed");
614 
615  if (record->getClasses().size() != 1) {
616  PrintFatalError(
617  def.getLoc(),
618  "unsupported attribute modelling, only single class expected");
619  }
620  attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(),
621  Attribute(cast<DefInit>(val.getValue()))});
622  }
623  }
624 
625  // Populate `arguments`. This must happen after we've finalized `operands` and
626  // `attributes` because we will put their elements' pointers in `arguments`.
627  // SmallVector may perform re-allocation under the hood when adding new
628  // elements.
629  int operandIndex = 0, attrIndex = 0, propIndex = 0;
630  for (unsigned i = 0; i != numArgs; ++i) {
631  const Record *argDef =
632  dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
633  if (argDef->isSubClassOf(opVarClass))
634  argDef = argDef->getValueAsDef("constraint");
635 
636  if (argDef->isSubClassOf(typeConstraintClass)) {
637  attrOrOperandMapping.push_back(
638  {OperandOrAttribute::Kind::Operand, operandIndex});
639  arguments.emplace_back(&operands[operandIndex++]);
640  } else if (argDef->isSubClassOf(attrClass)) {
641  attrOrOperandMapping.push_back(
643  arguments.emplace_back(&attributes[attrIndex++]);
644  } else {
645  assert(argDef->isSubClassOf(propertyClass));
646  arguments.emplace_back(&properties[propIndex++]);
647  }
648  }
649 
650  auto *resultsDag = def.getValueAsDag("results");
651  auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
652  if (!outsOp || outsOp->getDef()->getName() != "outs") {
653  PrintFatalError(def.getLoc(), "'results' must have 'outs' directive");
654  }
655 
656  // Handle results.
657  for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
658  auto name = resultsDag->getArgNameStr(i);
659  auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
660  if (!resultInit) {
661  PrintFatalError(def.getLoc(),
662  Twine("undefined type for result #") + Twine(i));
663  }
664  auto *resultDef = resultInit->getDef();
665  if (resultDef->isSubClassOf(opVarClass))
666  resultDef = resultDef->getValueAsDef("constraint");
667  results.push_back({name, TypeConstraint(resultDef)});
668  if (!name.empty())
669  argumentsAndResultsIndex[name] = InferredResultType::mapResultIndex(i);
670 
671  // We currently only support VariadicOfVariadic for operands, not results.
672  if (results.back().constraint.isVariadicOfVariadic()) {
673  PrintFatalError(
674  def.getLoc(),
675  "'VariadicOfVariadic' results are currently not supported");
676  }
677  }
678 
679  // Handle successors
680  auto *successorsDag = def.getValueAsDag("successors");
681  auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
682  if (!successorsOp || successorsOp->getDef()->getName() != "successor") {
683  PrintFatalError(def.getLoc(),
684  "'successors' must have 'successor' directive");
685  }
686 
687  for (unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
688  auto name = successorsDag->getArgNameStr(i);
689  auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
690  if (!successorInit) {
691  PrintFatalError(def.getLoc(),
692  Twine("undefined kind for successor #") + Twine(i));
693  }
694  Successor successor(successorInit->getDef());
695 
696  // Only support variadic successors if it is the last one for now.
697  if (i != e - 1 && successor.isVariadic())
698  PrintFatalError(def.getLoc(), "only the last successor can be variadic");
699  successors.push_back({name, successor});
700  }
701 
702  // Create list of traits, skipping over duplicates: appending to lists in
703  // tablegen is easy, making them unique less so, so dedupe here.
704  if (auto *traitList = def.getValueAsListInit("traits")) {
705  // This is uniquing based on pointers of the trait.
707  traits.reserve(traitSet.size());
708 
709  // The declaration order of traits imply the verification order of traits.
710  // Some traits may require other traits to be verified first then they can
711  // do further verification based on those verified facts. If you see this
712  // error, fix the traits declaration order by checking the `dependentTraits`
713  // field.
714  auto verifyTraitValidity = [&](const Record *trait) {
715  auto *dependentTraits = trait->getValueAsListInit("dependentTraits");
716  for (auto *traitInit : *dependentTraits)
717  if (!traitSet.contains(traitInit))
718  PrintFatalError(
719  def.getLoc(),
720  trait->getValueAsString("trait") + " requires " +
721  cast<DefInit>(traitInit)->getDef()->getValueAsString(
722  "trait") +
723  " to precede it in traits list");
724  };
725 
726  std::function<void(const ListInit *)> insert;
727  insert = [&](const ListInit *traitList) {
728  for (auto *traitInit : *traitList) {
729  auto *def = cast<DefInit>(traitInit)->getDef();
730  if (def->isSubClassOf("TraitList")) {
731  insert(def->getValueAsListInit("traits"));
732  continue;
733  }
734 
735  // Ignore duplicates.
736  if (!traitSet.insert(traitInit).second)
737  continue;
738 
739  // If this is an interface with base classes, add the bases to the
740  // trait list.
741  if (def->isSubClassOf("Interface"))
742  insert(def->getValueAsListInit("baseInterfaces"));
743 
744  // Verify if the trait has all the dependent traits declared before
745  // itself.
746  verifyTraitValidity(def);
747  traits.push_back(Trait::create(traitInit));
748  }
749  };
750  insert(traitList);
751  }
752 
753  populateTypeInferenceInfo(argumentsAndResultsIndex);
754 
755  // Handle regions
756  auto *regionsDag = def.getValueAsDag("regions");
757  auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
758  if (!regionsOp || regionsOp->getDef()->getName() != "region") {
759  PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
760  }
761 
762  for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
763  auto name = regionsDag->getArgNameStr(i);
764  auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
765  if (!regionInit) {
766  PrintFatalError(def.getLoc(),
767  Twine("undefined kind for region #") + Twine(i));
768  }
769  Region region(regionInit->getDef());
770  if (region.isVariadic()) {
771  // Only support variadic regions if it is the last one for now.
772  if (i != e - 1)
773  PrintFatalError(def.getLoc(), "only the last region can be variadic");
774  if (name.empty())
775  PrintFatalError(def.getLoc(), "variadic regions must be named");
776  }
777 
778  regions.push_back({name, region});
779  }
780 
781  // Populate the builders.
782  auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
783  if (builderList && !builderList->empty()) {
784  for (const Init *init : builderList->getValues())
785  builders.emplace_back(cast<DefInit>(init)->getDef(), def.getLoc());
786  } else if (skipDefaultBuilders()) {
787  PrintFatalError(
788  def.getLoc(),
789  "default builders are skipped and no custom builders provided");
790  }
791 
792  LLVM_DEBUG(print(llvm::dbgs()));
793 }
794 
796  assert(allResultTypesKnown());
797  return resultTypeMapping[index];
798 }
799 
800 ArrayRef<SMLoc> Operator::getLoc() const { return def.getLoc(); }
801 
803  return !getDescription().trim().empty();
804 }
805 
806 StringRef Operator::getDescription() const {
807  return def.getValueAsString("description");
808 }
809 
810 bool Operator::hasSummary() const { return !getSummary().trim().empty(); }
811 
812 StringRef Operator::getSummary() const {
813  return def.getValueAsString("summary");
814 }
815 
817  auto *valueInit = def.getValueInit("assemblyFormat");
818  return isa<StringInit>(valueInit);
819 }
820 
821 StringRef Operator::getAssemblyFormat() const {
822  return TypeSwitch<const Init *, StringRef>(def.getValueInit("assemblyFormat"))
823  .Case<StringInit>([&](auto *init) { return init->getValue(); });
824 }
825 
826 void Operator::print(llvm::raw_ostream &os) const {
827  os << "op '" << getOperationName() << "'\n";
828  for (Argument arg : arguments) {
829  if (auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
830  os << "[attribute] " << attr->name << '\n';
831  else
832  os << "[operand] " << arg.get<NamedTypeConstraint *>()->name << '\n';
833  }
834 }
835 
837  -> VariableDecorator {
838  return VariableDecorator(cast<DefInit>(init)->getDef());
839 }
840 
842  -> OperandOrAttribute {
843  return attrOrOperandMapping[index];
844 }
845 
846 std::string Operator::getGetterName(StringRef name) const {
847  return "get" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
848 }
849 
850 std::string Operator::getSetterName(StringRef name) const {
851  return "set" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
852 }
853 
854 std::string Operator::getRemoverName(StringRef name) const {
855  return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
856 }
857 
858 bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); }
859 
861  return def.getValueAsBit("useCustomPropertiesEncoding");
862 }
static void assertAccessorInvariants(const Operator &op, StringRef name)
Assert the invariants of accessors generated for the given name.
Definition: Operator.cpp:82
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
StringRef getName() const
This class represents an inferred result type.
Definition: Operator.h:44
static int mapResultIndex(int i)
Definition: Operator.h:59
static int unmapResultIndex(int i)
Definition: Operator.h:60
static bool isResultIndex(int i)
Definition: Operator.h:61
static bool isArgIndex(int i)
Definition: Operator.h:62
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
Definition: Operator.h:77
std::string getQualCppClassName() const
Returns this op's C++ class name prefixed with namespaces.
Definition: Operator.cpp:158
unsigned getNumSuccessors() const
Returns the number of successors.
Definition: Operator.cpp:308
const NamedRegion & getRegion(unsigned index) const
Returns the index-th region.
Definition: Operator.cpp:288
TypeConstraint getResultTypeConstraint(int index) const
Returns the index-th result's type constraint.
Definition: Operator.cpp:203
ArrayRef< SMLoc > getLoc() const
Definition: Operator.cpp:800
Operator(const llvm::Record &def)
llvm::iterator_range< const_region_iterator > getRegions() const
Definition: Operator.cpp:281
OperandOrAttribute getArgToOperandOrAttribute(int index) const
Returns the OperandOrAttribute corresponding to the index.
Definition: Operator.cpp:841
NamedTypeConstraint & getOperand(int index)
Definition: Operator.h:216
StringRef getCppNamespace() const
Returns this op's C++ namespace.
Definition: Operator.cpp:164
const_attribute_iterator attribute_begin() const
Definition: Operator.cpp:329
std::string getGetterName(StringRef name) const
Returns the getter name for the accessor of name.
Definition: Operator.cpp:846
const_successor_iterator successor_end() const
Definition: Operator.cpp:300
int getNumOperands() const
Definition: Operator.h:215
StringRef getDescription() const
Definition: Operator.cpp:806
const_value_range getResults() const
Definition: Operator.cpp:199
arg_range getArgs() const
Definition: Operator.cpp:242
const_value_range getOperands() const
Definition: Operator.cpp:355
const_region_iterator region_begin() const
Definition: Operator.cpp:275
bool useCustomPropertiesEncoding() const
Whether to generate the readProperty/writeProperty methods for bytecode emission.
Definition: Operator.cpp:860
StringRef getResultName(int index) const
Returns the index-th result's name.
Definition: Operator.cpp:208
var_decorator_range getArgDecorators(int index) const
Definition: Operator.cpp:251
unsigned getNumVariableLengthOperands() const
Returns the number of variadic operands in this operation.
Definition: Operator.cpp:227
var_decorator_range getResultDecorators(int index) const
Returns the index-th result's decorators.
Definition: Operator.cpp:213
std::string getGenericAdaptorName() const
Returns the name of op's generic adaptor C++ class.
Definition: Operator.cpp:77
StringRef getExtraClassDefinition() const
Returns this op's extra class definition code.
Definition: Operator.cpp:178
NamedTypeConstraint & getResult(int index)
Returns the op result at the given index.
Definition: Operator.h:155
const_value_iterator result_begin() const
Op result iterators.
Definition: Operator.cpp:191
const_attribute_iterator attribute_end() const
Definition: Operator.cpp:332
const_trait_iterator trait_end() const
Definition: Operator.cpp:322
std::string getAdaptorName() const
Returns the name of op's adaptor C++ class.
Definition: Operator.cpp:73
int getNumResults() const
Returns the number of results this op produces.
Definition: Operator.cpp:166
llvm::iterator_range< const_attribute_iterator > getAttributes() const
Definition: Operator.cpp:335
bool hasFolder() const
Definition: Operator.cpp:858
const_value_iterator operand_end() const
Definition: Operator.cpp:352
arg_iterator arg_end() const
Definition: Operator.cpp:240
int getNumArgs() const
Returns the total number of arguments.
Definition: Operator.h:225
const_value_iterator operand_begin() const
Op operand iterators.
Definition: Operator.cpp:349
void assertInvariants() const
Check invariants (like no duplicated or conflicted names) and abort the process if any invariant is b...
Definition: Operator.cpp:116
StringRef getArgName(int index) const
Definition: Operator.cpp:246
StringRef getDialectName() const
Returns this op's dialect name.
Definition: Operator.cpp:154
const_region_iterator region_end() const
Definition: Operator.cpp:278
unsigned getNumVariableLengthResults() const
Returns the number of variable length results in this operation.
Definition: Operator.cpp:221
bool hasSingleVariadicArg() const
Returns true if the operation has a single variadic arg.
Definition: Operator.cpp:233
unsigned getNumVariadicSuccessors() const
Returns the number of variadic successors in this operation.
Definition: Operator.cpp:314
StringRef getSummary() const
Definition: Operator.cpp:812
bool isVariadic() const
Returns true if this op has variable length operands or results.
Definition: Operator.cpp:361
llvm::iterator_range< const_trait_iterator > getTraits() const
Definition: Operator.cpp:325
const Trait * getTrait(llvm::StringRef trait) const
Returns the trait wrapper for the given MLIR C++ trait.
Definition: Operator.cpp:259
llvm::iterator_range< const_successor_iterator > getSuccessors() const
Definition: Operator.cpp:303
bool hasSummary() const
Definition: Operator.cpp:810
const_successor_iterator successor_begin() const
Definition: Operator.cpp:297
void print(llvm::raw_ostream &os) const
Prints the contents in this operator to the given os.
Definition: Operator.cpp:826
unsigned getNumRegions() const
Returns the number of regions.
Definition: Operator.cpp:286
const_trait_iterator trait_begin() const
Definition: Operator.cpp:319
StringRef getExtraClassDeclaration() const
Returns this op's extra class declaration code.
Definition: Operator.cpp:171
StringRef getAssemblyFormat() const
Definition: Operator.cpp:821
std::string getSetterName(StringRef name) const
Returns the setter name for the accessor of name.
Definition: Operator.cpp:850
std::string getOperationName() const
Returns the operation name.
Definition: Operator.cpp:65
const NamedSuccessor & getSuccessor(unsigned index) const
Returns the index-th successor.
Definition: Operator.cpp:310
StringRef getCppClassName() const
Returns this op's C++ class name.
Definition: Operator.cpp:156
bool allResultTypesKnown() const
Return whether all the result types are known.
Definition: Operator.h:320
bool hasAssemblyFormat() const
Query functions for the assembly format of the operator.
Definition: Operator.cpp:816
unsigned getNumVariadicRegions() const
Returns the number of variadic regions in this operation.
Definition: Operator.cpp:292
bool skipDefaultBuilders() const
Returns true if default builders should not be generated.
Definition: Operator.cpp:187
arg_iterator arg_begin() const
Op argument (attribute or operand) iterators.
Definition: Operator.cpp:238
const InferredResultType & getInferredResultType(int index) const
Return all arguments or type constraints with same type as result[index].
Definition: Operator.cpp:795
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
Definition: Operator.cpp:185
const_value_iterator result_end() const
Definition: Operator.cpp:195
std::string getRemoverName(StringRef name) const
Returns the remover name for the accessor of name.
Definition: Operator.cpp:854
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
Definition: Operator.cpp:359
bool hasDescription() const
Query functions for the documentation of the operator.
Definition: Operator.cpp:802
static Trait create(const llvm::Init *init)
Definition: Trait.cpp:28
bool isVariableLength() const
Definition: Type.h:53
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
const char * inferTypeOpInterface
Definition: Attribute.cpp:243
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isVariadic() const
Definition: Region.h:33
Pair consisting kind of argument and index into operands or attributes.
Definition: Operator.h:327
static VariableDecorator unwrap(const llvm::Init *init)
Definition: Operator.cpp:836
A class used to represent the decorators of an operator variable, i.e.
Definition: Operator.h:110