18 #include "llvm/ADT/EquivalenceClasses.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
31 #define DEBUG_TYPE "mlir-tblgen-operator"
41 using llvm::StringInit;
44 : dialect(def.getValueAsDef(
"opDialect")), def(def) {
50 std::tie(prefix, cppClassName) = def.getName().split(
'_');
53 cppClassName = def.getName();
54 }
else if (cppClassName.empty()) {
56 cppClassName = prefix;
59 cppNamespace = def.getValueAsString(
"cppNamespace");
61 populateOpStructure();
65 std::string Operator::getOperationName()
const {
66 auto prefix = dialect.
getName();
67 auto opName = def.getValueAsString(
"opName");
69 return std::string(opName);
70 return std::string(llvm::formatv(
"{0}.{1}", prefix, opName));
78 return std::string(llvm::formatv(
"{0}GenericAdaptor",
getCppClassName()));
83 std::string accessorName =
84 convertToCamelFromSnakeCase(name,
true);
94 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
95 if (newName ==
"AttributeNames" || newName ==
"Attributes" ||
96 newName ==
"Operation")
98 if (newName ==
"Operands")
100 if (newName ==
"Regions")
102 if (newName ==
"Type")
106 if (nameOverlapsWithOpAPI(accessorName)) {
110 PrintFatalError(op.
getLoc(),
"generated accessor for `" + name +
111 "` overlaps with a default one; please "
112 "rename to avoid overlap");
119 auto checkName = [&](StringRef name, StringRef entity) {
122 auto insertion = existingNames.insert({name, entity});
123 if (insertion.second) {
128 if (entity == insertion.first->second)
129 PrintFatalError(
getLoc(),
"op has a conflict with two " + entity +
130 " having the same name '" + name +
"'");
131 PrintFatalError(
getLoc(),
"op has a conflict with " +
132 insertion.first->second +
" and " + entity +
133 " both having an entry with the name '" +
159 if (cppNamespace.empty())
160 return std::string(cppClassName);
161 return std::string(llvm::formatv(
"{0}::{1}", cppNamespace, cppClassName));
167 const DagInit *results = def.getValueAsDag(
"results");
168 return results->getNumArgs();
172 constexpr
auto attr =
"extraClassDeclaration";
173 if (def.isValueUnset(attr))
175 return def.getValueAsString(attr);
179 constexpr
auto attr =
"extraClassDefinition";
180 if (def.isValueUnset(attr))
182 return def.getValueAsString(attr);
188 return def.getValueAsBit(
"skipDefaultBuilders");
192 return results.begin();
196 return results.end();
204 const DagInit *results = def.getValueAsDag(
"results");
209 const DagInit *results = def.getValueAsDag(
"results");
210 return results->getArgNameStr(index);
214 const Record *result =
215 cast<DefInit>(def.getValueAsDag(
"results")->getArg(index))->getDef();
216 if (!result->isSubClassOf(
"OpVariable"))
218 return *result->getValueAsListInit(
"decorators");
247 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
248 return argumentValues->getArgNameStr(index);
253 cast<DefInit>(def.getValueAsDag(
"arguments")->getArg(index))->getDef();
254 if (!arg->isSubClassOf(
"OpVariable"))
256 return *arg->getValueAsListInit(
"decorators");
260 for (
const auto &t : traits) {
261 if (
const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
262 if (traitDef->getFullyQualifiedTraitName() == trait)
264 }
else if (
const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
265 if (traitDef->getFullyQualifiedTraitName() == trait)
267 }
else if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
268 if (traitDef->getFullyQualifiedTraitName() == trait)
276 return regions.begin();
279 return regions.end();
289 return regions[index];
293 return llvm::count_if(regions,
298 return successors.begin();
301 return successors.end();
311 return successors[index];
315 return llvm::count_if(successors,
320 return traits.begin();
330 return attributes.begin();
333 return attributes.end();
340 return attributes.begin();
343 return attributes.end();
346 return {attribute_begin(), attribute_end()};
350 return operands.begin();
353 return operands.end();
362 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
366 void Operator::populateTypeInferenceInfo(
367 const llvm::StringMap<int> &argumentsAndResultsIndex) {
370 auto &recordKeeper = def.getRecords();
372 allResultsHaveKnownTypes =
false;
387 if (
getTrait(
"::mlir::OpTrait::SameOperandsAndResultType")) {
389 auto *operandI = llvm::find_if(arguments, [](
const Argument &arg) {
393 if (operandI == arguments.end())
397 int operandIdx = operandI - arguments.begin();
399 resultTypeMapping.emplace_back(operandIdx,
"$_self");
401 allResultsHaveKnownTypes =
true;
411 struct ResultTypeInference {
415 bool inferred =
false;
425 if (
getResult(idx).constraint.getBuilderCall()) {
428 infer.inferred =
true;
434 for (
const Trait &trait : traits) {
435 const Record &def = trait.getDef();
441 if (def.isSubClassOf(
444 if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
445 if (&traitDef->getDef() == inferTrait)
450 if (def.isSubClassOf(
"TypesMatchWith")) {
451 int target = argumentsAndResultsIndex.lookup(def.getValueAsString(
"rhs"));
456 ResultTypeInference &infer = inference[resultIndex];
461 argumentsAndResultsIndex.lookup(def.getValueAsString(
"lhs"));
462 infer.sources.emplace_back(sourceIndex,
463 def.getValueAsString(
"transformer").str());
471 if (!def.isSubClassOf(
"AllTypesMatch"))
474 auto values = def.getValueAsListOfStrings(
"values");
480 std::optional<int> fullyInferredIndex;
482 for (StringRef name : values) {
483 int index = argumentsAndResultsIndex.lookup(name);
488 fullyInferredIndex = index;
490 if (fullyInferredIndex) {
493 for (
int resultIndex : resultIndices) {
494 ResultTypeInference &infer = inference[resultIndex];
495 if (!infer.inferred) {
496 infer.sources.assign(1, {*fullyInferredIndex,
"$_self"});
497 infer.inferred =
true;
502 for (
int resultIndex : resultIndices) {
503 for (
int otherResultIndex : resultIndices) {
504 if (resultIndex == otherResultIndex)
506 inference[resultIndex].sources.emplace_back(otherResultIndex,
514 std::vector<ResultTypeInference *> worklist;
515 for (ResultTypeInference &infer : inference)
517 worklist.push_back(&infer);
521 for (
auto cur = worklist.begin(); cur != worklist.end();) {
522 ResultTypeInference &infer = **cur;
526 assert(InferredResultType::isResultIndex(source.getIndex()));
527 return inference[InferredResultType::unmapResultIndex(
531 if (iter == infer.sources.end()) {
537 infer.inferred =
true;
539 infer.sources.assign(1, *iter);
540 cur = worklist.erase(cur);
544 allResultsHaveKnownTypes = worklist.empty();
547 if (allResultsHaveKnownTypes) {
549 for (
const ResultTypeInference &infer : inference)
550 resultTypeMapping.push_back(infer.sources.front());
554 void Operator::populateOpStructure() {
555 auto &recordKeeper = def.getRecords();
556 auto *typeConstraintClass = recordKeeper.getClass(
"TypeConstraint");
557 auto *attrClass = recordKeeper.getClass(
"Attr");
558 auto *propertyClass = recordKeeper.getClass(
"Property");
559 auto *derivedAttrClass = recordKeeper.getClass(
"DerivedAttr");
560 auto *opVarClass = recordKeeper.getClass(
"OpVariable");
561 numNativeAttributes = 0;
563 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
564 unsigned numArgs = argumentValues->getNumArgs();
568 llvm::StringMap<int> argumentsAndResultsIndex;
571 for (
unsigned i = 0; i != numArgs; ++i) {
572 auto *arg = argumentValues->getArg(i);
573 auto givenName = argumentValues->getArgNameStr(i);
574 auto *argDefInit = dyn_cast<DefInit>(arg);
576 PrintFatalError(def.getLoc(),
577 Twine(
"undefined type for argument #") + Twine(i));
578 const Record *argDef = argDefInit->getDef();
579 if (argDef->isSubClassOf(opVarClass))
580 argDef = argDef->getValueAsDef(
"constraint");
582 if (argDef->isSubClassOf(typeConstraintClass)) {
585 }
else if (argDef->isSubClassOf(attrClass)) {
586 if (givenName.empty())
587 PrintFatalError(argDef->getLoc(),
"attributes must be named");
588 if (argDef->isSubClassOf(derivedAttrClass))
589 PrintFatalError(argDef->getLoc(),
590 "derived attributes not allowed in argument list");
591 attributes.push_back({givenName,
Attribute(argDef)});
592 ++numNativeAttributes;
593 }
else if (argDef->isSubClassOf(propertyClass)) {
594 if (givenName.empty())
595 PrintFatalError(argDef->getLoc(),
"properties must be named");
596 properties.push_back({givenName,
Property(argDef)});
598 PrintFatalError(def.getLoc(),
599 "unexpected def type; only defs deriving "
600 "from TypeConstraint or Attr or Property are allowed");
602 if (!givenName.empty())
603 argumentsAndResultsIndex[givenName] = i;
607 for (
const auto &val : def.getValues()) {
608 if (
auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
609 if (!record->isSubClassOf(attrClass))
611 if (!record->isSubClassOf(derivedAttrClass))
612 PrintFatalError(def.getLoc(),
613 "unexpected Attr where only DerivedAttr is allowed");
615 if (record->getClasses().size() != 1) {
618 "unsupported attribute modelling, only single class expected");
620 attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(),
621 Attribute(cast<DefInit>(val.getValue()))});
629 int operandIndex = 0, attrIndex = 0, propIndex = 0;
630 for (
unsigned i = 0; i != numArgs; ++i) {
631 const Record *argDef =
632 dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
633 if (argDef->isSubClassOf(opVarClass))
634 argDef = argDef->getValueAsDef(
"constraint");
636 if (argDef->isSubClassOf(typeConstraintClass)) {
637 attrOrOperandMapping.push_back(
639 arguments.emplace_back(&operands[operandIndex++]);
640 }
else if (argDef->isSubClassOf(attrClass)) {
641 attrOrOperandMapping.push_back(
643 arguments.emplace_back(&attributes[attrIndex++]);
645 assert(argDef->isSubClassOf(propertyClass));
646 arguments.emplace_back(&properties[propIndex++]);
650 auto *resultsDag = def.getValueAsDag(
"results");
651 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
652 if (!outsOp || outsOp->getDef()->getName() !=
"outs") {
653 PrintFatalError(def.getLoc(),
"'results' must have 'outs' directive");
657 for (
unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
658 auto name = resultsDag->getArgNameStr(i);
659 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
661 PrintFatalError(def.getLoc(),
662 Twine(
"undefined type for result #") + Twine(i));
664 auto *resultDef = resultInit->getDef();
665 if (resultDef->isSubClassOf(opVarClass))
666 resultDef = resultDef->getValueAsDef(
"constraint");
672 if (results.back().constraint.isVariadicOfVariadic()) {
675 "'VariadicOfVariadic' results are currently not supported");
680 auto *successorsDag = def.getValueAsDag(
"successors");
681 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
682 if (!successorsOp || successorsOp->getDef()->getName() !=
"successor") {
683 PrintFatalError(def.getLoc(),
684 "'successors' must have 'successor' directive");
687 for (
unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
688 auto name = successorsDag->getArgNameStr(i);
689 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
690 if (!successorInit) {
691 PrintFatalError(def.getLoc(),
692 Twine(
"undefined kind for successor #") + Twine(i));
694 Successor successor(successorInit->getDef());
697 if (i != e - 1 && successor.isVariadic())
698 PrintFatalError(def.getLoc(),
"only the last successor can be variadic");
699 successors.push_back({name, successor});
704 if (
auto *traitList = def.getValueAsListInit(
"traits")) {
707 traits.reserve(traitSet.size());
714 auto verifyTraitValidity = [&](
const Record *trait) {
715 auto *dependentTraits = trait->getValueAsListInit(
"dependentTraits");
716 for (
auto *traitInit : *dependentTraits)
717 if (!traitSet.contains(traitInit))
720 trait->getValueAsString(
"trait") +
" requires " +
721 cast<DefInit>(traitInit)->getDef()->getValueAsString(
723 " to precede it in traits list");
726 std::function<void(
const ListInit *)> insert;
727 insert = [&](
const ListInit *traitList) {
728 for (
auto *traitInit : *traitList) {
729 auto *def = cast<DefInit>(traitInit)->getDef();
730 if (def->isSubClassOf(
"TraitList")) {
731 insert(def->getValueAsListInit(
"traits"));
736 if (!traitSet.insert(traitInit).second)
741 if (def->isSubClassOf(
"Interface"))
742 insert(def->getValueAsListInit(
"baseInterfaces"));
746 verifyTraitValidity(def);
753 populateTypeInferenceInfo(argumentsAndResultsIndex);
756 auto *regionsDag = def.getValueAsDag(
"regions");
757 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
758 if (!regionsOp || regionsOp->getDef()->getName() !=
"region") {
759 PrintFatalError(def.getLoc(),
"'regions' must have 'region' directive");
762 for (
unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
763 auto name = regionsDag->getArgNameStr(i);
764 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
766 PrintFatalError(def.getLoc(),
767 Twine(
"undefined kind for region #") + Twine(i));
769 Region region(regionInit->getDef());
770 if (region.isVariadic()) {
773 PrintFatalError(def.getLoc(),
"only the last region can be variadic");
775 PrintFatalError(def.getLoc(),
"variadic regions must be named");
778 regions.push_back({name, region});
782 auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit(
"builders"));
783 if (builderList && !builderList->empty()) {
784 for (
const Init *init : builderList->getValues())
785 builders.emplace_back(cast<DefInit>(init)->
getDef(), def.getLoc());
789 "default builders are skipped and no custom builders provided");
792 LLVM_DEBUG(
print(llvm::dbgs()));
797 return resultTypeMapping[index];
807 return def.getValueAsString(
"description");
813 return def.getValueAsString(
"summary");
817 auto *valueInit = def.getValueInit(
"assemblyFormat");
818 return isa<StringInit>(valueInit);
823 .Case<StringInit>([&](
auto *init) {
return init->getValue(); });
829 if (
auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
830 os <<
"[attribute] " << attr->name <<
'\n';
843 return attrOrOperandMapping[index];
847 return "get" + convertToCamelFromSnakeCase(name,
true);
851 return "set" + convertToCamelFromSnakeCase(name,
true);
855 return "remove" + convertToCamelFromSnakeCase(name,
true);
861 return def.getValueAsBit(
"useCustomPropertiesEncoding");
static void assertAccessorInvariants(const Operator &op, StringRef name)
Assert the invariants of accessors generated for the given name.
Attributes are known-constant values of operations.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
StringRef getName() const
This class represents an inferred result type.
static int mapResultIndex(int i)
static int unmapResultIndex(int i)
static bool isResultIndex(int i)
static bool isArgIndex(int i)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
std::string getQualCppClassName() const
Returns this op's C++ class name prefixed with namespaces.
unsigned getNumSuccessors() const
Returns the number of successors.
const NamedRegion & getRegion(unsigned index) const
Returns the index-th region.
TypeConstraint getResultTypeConstraint(int index) const
Returns the index-th result's type constraint.
ArrayRef< SMLoc > getLoc() const
Operator(const llvm::Record &def)
llvm::iterator_range< const_region_iterator > getRegions() const
OperandOrAttribute getArgToOperandOrAttribute(int index) const
Returns the OperandOrAttribute corresponding to the index.
NamedTypeConstraint & getOperand(int index)
StringRef getCppNamespace() const
Returns this op's C++ namespace.
const_attribute_iterator attribute_begin() const
std::string getGetterName(StringRef name) const
Returns the getter name for the accessor of name.
const_successor_iterator successor_end() const
int getNumOperands() const
StringRef getDescription() const
const_value_range getResults() const
arg_range getArgs() const
const_value_range getOperands() const
const_region_iterator region_begin() const
bool useCustomPropertiesEncoding() const
Whether to generate the readProperty/writeProperty methods for bytecode emission.
StringRef getResultName(int index) const
Returns the index-th result's name.
var_decorator_range getArgDecorators(int index) const
unsigned getNumVariableLengthOperands() const
Returns the number of variadic operands in this operation.
var_decorator_range getResultDecorators(int index) const
Returns the index-th result's decorators.
std::string getGenericAdaptorName() const
Returns the name of op's generic adaptor C++ class.
StringRef getExtraClassDefinition() const
Returns this op's extra class definition code.
NamedTypeConstraint & getResult(int index)
Returns the op result at the given index.
const_value_iterator result_begin() const
Op result iterators.
const_attribute_iterator attribute_end() const
const_trait_iterator trait_end() const
std::string getAdaptorName() const
Returns the name of op's adaptor C++ class.
int getNumResults() const
Returns the number of results this op produces.
llvm::iterator_range< const_attribute_iterator > getAttributes() const
const_value_iterator operand_end() const
arg_iterator arg_end() const
int getNumArgs() const
Returns the total number of arguments.
const_value_iterator operand_begin() const
Op operand iterators.
void assertInvariants() const
Check invariants (like no duplicated or conflicted names) and abort the process if any invariant is b...
StringRef getArgName(int index) const
StringRef getDialectName() const
Returns this op's dialect name.
const_region_iterator region_end() const
unsigned getNumVariableLengthResults() const
Returns the number of variable length results in this operation.
bool hasSingleVariadicArg() const
Returns true of the operation has a single variadic arg.
unsigned getNumVariadicSuccessors() const
Returns the number of variadic successors in this operation.
StringRef getSummary() const
bool isVariadic() const
Returns true if this op has variable length operands or results.
llvm::iterator_range< const_trait_iterator > getTraits() const
const Trait * getTrait(llvm::StringRef trait) const
Returns the trait wrapper for the given MLIR C++ trait.
llvm::iterator_range< const_successor_iterator > getSuccessors() const
const_successor_iterator successor_begin() const
void print(llvm::raw_ostream &os) const
Prints the contents in this operator to the given os.
unsigned getNumRegions() const
Returns the number of regions.
const_trait_iterator trait_begin() const
StringRef getExtraClassDeclaration() const
Returns this op's extra class declaration code.
StringRef getAssemblyFormat() const
std::string getSetterName(StringRef name) const
Returns the setter name for the accessor of name.
std::string getOperationName() const
Returns the operation name.
const NamedSuccessor & getSuccessor(unsigned index) const
Returns the index-th successor.
StringRef getCppClassName() const
Returns this op's C++ class name.
bool allResultTypesKnown() const
Return whether all the result types are known.
bool hasAssemblyFormat() const
Query functions for the assembly format of the operator.
unsigned getNumVariadicRegions() const
Returns the number of variadic regions in this operation.
bool skipDefaultBuilders() const
Returns true if default builders should not be generated.
arg_iterator arg_begin() const
Op argument (attribute or operand) iterators.
const InferredResultType & getInferredResultType(int index) const
Return all arguments or type constraints with same type as result[index].
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
const_value_iterator result_end() const
std::string getRemoverName(StringRef name) const
Returns the remove name for the accessor of name.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
bool hasDescription() const
Query functions for the documentation of the operator.
static Trait create(const llvm::Init *init)
bool isVariableLength() const
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
const char * inferTypeOpInterface
Include the generated interface declarations.
bool isVariableLength() const
TypeConstraint constraint
Pair consisting kind of argument and index into operands or attributes.
static VariableDecorator unwrap(const llvm::Init *init)
A class used to represent the decorators of an operator variable, i.e.