18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
29 #define DEBUG_TYPE "mlir-tblgen-operator"
39 using llvm::StringInit;
42 : dialect(def.getValueAsDef(
"opDialect")), def(def) {
48 std::tie(prefix, cppClassName) = def.getName().split(
'_');
51 cppClassName = def.getName();
52 }
else if (cppClassName.empty()) {
54 cppClassName = prefix;
57 cppNamespace = def.getValueAsString(
"cppNamespace");
59 populateOpStructure();
63 std::string Operator::getOperationName()
const {
64 auto prefix = dialect.
getName();
65 auto opName = def.getValueAsString(
"opName");
67 return std::string(opName);
68 return std::string(llvm::formatv(
"{0}.{1}", prefix, opName));
76 return std::string(llvm::formatv(
"{0}GenericAdaptor",
getCppClassName()));
81 std::string accessorName =
82 convertToCamelFromSnakeCase(name,
true);
92 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
93 if (newName ==
"AttributeNames" || newName ==
"Attributes" ||
94 newName ==
"Operation")
96 if (newName ==
"Operands")
98 if (newName ==
"Regions")
100 if (newName ==
"Type")
104 if (nameOverlapsWithOpAPI(accessorName)) {
108 PrintFatalError(op.
getLoc(),
"generated accessor for `" + name +
109 "` overlaps with a default one; please "
110 "rename to avoid overlap");
117 auto checkName = [&](StringRef name, StringRef entity) {
120 auto insertion = existingNames.insert({name, entity});
121 if (insertion.second) {
126 if (entity == insertion.first->second)
127 PrintFatalError(
getLoc(),
"op has a conflict with two " + entity +
128 " having the same name '" + name +
"'");
129 PrintFatalError(
getLoc(),
"op has a conflict with " +
130 insertion.first->second +
" and " + entity +
131 " both having an entry with the name '" +
157 if (cppNamespace.empty())
158 return std::string(cppClassName);
159 return std::string(llvm::formatv(
"{0}::{1}", cppNamespace, cppClassName));
165 const DagInit *results = def.getValueAsDag(
"results");
166 return results->getNumArgs();
170 constexpr
auto attr =
"extraClassDeclaration";
171 if (def.isValueUnset(attr))
173 return def.getValueAsString(attr);
177 constexpr
auto attr =
"extraClassDefinition";
178 if (def.isValueUnset(attr))
180 return def.getValueAsString(attr);
186 return def.getValueAsBit(
"skipDefaultBuilders");
190 return results.begin();
194 return results.end();
202 const DagInit *results = def.getValueAsDag(
"results");
207 const DagInit *results = def.getValueAsDag(
"results");
208 return results->getArgNameStr(index);
212 const Record *result =
213 cast<DefInit>(def.getValueAsDag(
"results")->getArg(index))->getDef();
214 if (!result->isSubClassOf(
"OpVariable"))
216 return *result->getValueAsListInit(
"decorators");
245 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
246 return argumentValues->getArgNameStr(index);
251 cast<DefInit>(def.getValueAsDag(
"arguments")->getArg(index))->getDef();
252 if (!arg->isSubClassOf(
"OpVariable"))
254 return *arg->getValueAsListInit(
"decorators");
258 for (
const auto &t : traits) {
259 if (
const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
260 if (traitDef->getFullyQualifiedTraitName() == trait)
262 }
else if (
const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
263 if (traitDef->getFullyQualifiedTraitName() == trait)
265 }
else if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
266 if (traitDef->getFullyQualifiedTraitName() == trait)
274 return regions.begin();
277 return regions.end();
287 return regions[index];
291 return llvm::count_if(regions,
296 return successors.begin();
299 return successors.end();
309 return successors[index];
313 return llvm::count_if(successors,
318 return traits.begin();
328 return attributes.begin();
331 return attributes.end();
338 return attributes.begin();
341 return attributes.end();
344 return {attribute_begin(), attribute_end()};
348 return operands.begin();
351 return operands.end();
360 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
364 void Operator::populateTypeInferenceInfo(
365 const llvm::StringMap<int> &argumentsAndResultsIndex) {
368 auto &recordKeeper = def.getRecords();
370 allResultsHaveKnownTypes =
false;
385 if (
getTrait(
"::mlir::OpTrait::SameOperandsAndResultType")) {
387 auto *operandI = llvm::find_if(arguments, [](
const Argument &arg) {
389 llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
392 if (operandI == arguments.end())
396 int operandIdx = operandI - arguments.begin();
398 resultTypeMapping.emplace_back(operandIdx,
"$_self");
400 allResultsHaveKnownTypes =
true;
410 struct ResultTypeInference {
414 bool inferred =
false;
424 if (
getResult(idx).constraint.getBuilderCall()) {
427 infer.inferred =
true;
433 for (
const Trait &trait : traits) {
434 const Record &def = trait.getDef();
440 if (def.isSubClassOf(
443 if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
444 if (&traitDef->getDef() == inferTrait)
449 if (def.isSubClassOf(
"TypesMatchWith")) {
450 int target = argumentsAndResultsIndex.lookup(def.getValueAsString(
"rhs"));
455 ResultTypeInference &infer = inference[resultIndex];
460 argumentsAndResultsIndex.lookup(def.getValueAsString(
"lhs"));
461 infer.sources.emplace_back(sourceIndex,
462 def.getValueAsString(
"transformer").str());
473 if (def.isSubClassOf(
"ShapedTypeMatchesElementCountAndTypes")) {
474 StringRef shapedArg = def.getValueAsString(
"shaped");
475 StringRef elementsArg = def.getValueAsString(
"elements");
477 int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
478 int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);
484 ResultTypeInference &infer = inference[resultIndex];
485 if (!infer.inferred) {
486 infer.sources.emplace_back(
488 "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
489 "ShapedType>($_self).getNumElements(), "
490 "::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
491 infer.inferred =
true;
501 if (!def.isSubClassOf(
"AllTypesMatch"))
504 auto values = def.getValueAsListOfStrings(
"values");
510 std::optional<int> fullyInferredIndex;
512 for (StringRef name : values) {
513 int index = argumentsAndResultsIndex.lookup(name);
518 fullyInferredIndex = index;
520 if (fullyInferredIndex) {
523 for (
int resultIndex : resultIndices) {
524 ResultTypeInference &infer = inference[resultIndex];
525 if (!infer.inferred) {
526 infer.sources.assign(1, {*fullyInferredIndex,
"$_self"});
527 infer.inferred =
true;
532 for (
int resultIndex : resultIndices) {
533 for (
int otherResultIndex : resultIndices) {
534 if (resultIndex == otherResultIndex)
536 inference[resultIndex].sources.emplace_back(
544 std::vector<ResultTypeInference *> worklist;
545 for (ResultTypeInference &infer : inference)
547 worklist.push_back(&infer);
551 for (
auto cur = worklist.begin(); cur != worklist.end();) {
552 ResultTypeInference &infer = **cur;
556 assert(InferredResultType::isResultIndex(source.getIndex()));
557 return inference[InferredResultType::unmapResultIndex(
561 if (iter == infer.sources.end()) {
567 infer.inferred =
true;
569 infer.sources.assign(1, *iter);
570 cur = worklist.erase(cur);
574 allResultsHaveKnownTypes = worklist.empty();
577 if (allResultsHaveKnownTypes) {
579 for (
const ResultTypeInference &infer : inference)
580 resultTypeMapping.push_back(infer.sources.front());
584 void Operator::populateOpStructure() {
585 auto &recordKeeper = def.getRecords();
586 auto *typeConstraintClass = recordKeeper.getClass(
"TypeConstraint");
587 auto *attrClass = recordKeeper.getClass(
"Attr");
588 auto *propertyClass = recordKeeper.getClass(
"Property");
589 auto *derivedAttrClass = recordKeeper.getClass(
"DerivedAttr");
590 auto *opVarClass = recordKeeper.getClass(
"OpVariable");
591 numNativeAttributes = 0;
593 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
594 unsigned numArgs = argumentValues->getNumArgs();
598 llvm::StringMap<int> argumentsAndResultsIndex;
601 for (
unsigned i = 0; i != numArgs; ++i) {
602 auto *arg = argumentValues->getArg(i);
603 auto givenName = argumentValues->getArgNameStr(i);
604 auto *argDefInit = dyn_cast<DefInit>(arg);
606 PrintFatalError(def.getLoc(),
607 Twine(
"undefined type for argument #") + Twine(i));
608 const Record *argDef = argDefInit->getDef();
609 if (argDef->isSubClassOf(opVarClass))
610 argDef = argDef->getValueAsDef(
"constraint");
612 if (argDef->isSubClassOf(typeConstraintClass)) {
615 }
else if (argDef->isSubClassOf(attrClass)) {
616 if (givenName.empty())
617 PrintFatalError(argDef->getLoc(),
"attributes must be named");
618 if (argDef->isSubClassOf(derivedAttrClass))
619 PrintFatalError(argDef->getLoc(),
620 "derived attributes not allowed in argument list");
621 attributes.push_back({givenName,
Attribute(argDef)});
622 ++numNativeAttributes;
623 }
else if (argDef->isSubClassOf(propertyClass)) {
624 if (givenName.empty())
625 PrintFatalError(argDef->getLoc(),
"properties must be named");
626 properties.push_back({givenName,
Property(argDef)});
628 PrintFatalError(def.getLoc(),
629 "unexpected def type; only defs deriving "
630 "from TypeConstraint or Attr or Property are allowed");
632 if (!givenName.empty())
633 argumentsAndResultsIndex[givenName] = i;
637 for (
const auto &val : def.getValues()) {
638 if (
auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
639 if (!record->isSubClassOf(attrClass))
641 if (!record->isSubClassOf(derivedAttrClass))
642 PrintFatalError(def.getLoc(),
643 "unexpected Attr where only DerivedAttr is allowed");
645 if (record->getClasses().size() != 1) {
648 "unsupported attribute modelling, only single class expected");
650 attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(),
651 Attribute(cast<DefInit>(val.getValue()))});
659 int operandIndex = 0, attrIndex = 0, propIndex = 0;
660 for (
unsigned i = 0; i != numArgs; ++i) {
661 const Record *argDef =
662 dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
663 if (argDef->isSubClassOf(opVarClass))
664 argDef = argDef->getValueAsDef(
"constraint");
666 if (argDef->isSubClassOf(typeConstraintClass)) {
667 attrPropOrOperandMapping.push_back(
669 arguments.emplace_back(&operands[operandIndex++]);
670 }
else if (argDef->isSubClassOf(attrClass)) {
671 attrPropOrOperandMapping.push_back(
673 arguments.emplace_back(&attributes[attrIndex++]);
675 assert(argDef->isSubClassOf(propertyClass));
676 attrPropOrOperandMapping.push_back(
678 arguments.emplace_back(&properties[propIndex++]);
682 auto *resultsDag = def.getValueAsDag(
"results");
683 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
684 if (!outsOp || outsOp->getDef()->getName() !=
"outs") {
685 PrintFatalError(def.getLoc(),
"'results' must have 'outs' directive");
689 for (
unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
690 auto name = resultsDag->getArgNameStr(i);
691 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
693 PrintFatalError(def.getLoc(),
694 Twine(
"undefined type for result #") + Twine(i));
696 auto *resultDef = resultInit->getDef();
697 if (resultDef->isSubClassOf(opVarClass))
698 resultDef = resultDef->getValueAsDef(
"constraint");
704 if (results.back().constraint.isVariadicOfVariadic()) {
707 "'VariadicOfVariadic' results are currently not supported");
712 auto *successorsDag = def.getValueAsDag(
"successors");
713 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
714 if (!successorsOp || successorsOp->getDef()->getName() !=
"successor") {
715 PrintFatalError(def.getLoc(),
716 "'successors' must have 'successor' directive");
719 for (
unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
720 auto name = successorsDag->getArgNameStr(i);
721 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
722 if (!successorInit) {
723 PrintFatalError(def.getLoc(),
724 Twine(
"undefined kind for successor #") + Twine(i));
726 Successor successor(successorInit->getDef());
729 if (i != e - 1 && successor.isVariadic())
730 PrintFatalError(def.getLoc(),
"only the last successor can be variadic");
731 successors.push_back({name, successor});
736 if (
auto *traitList = def.getValueAsListInit(
"traits")) {
739 traits.reserve(traitSet.size());
746 auto verifyTraitValidity = [&](
const Record *trait) {
747 auto *dependentTraits = trait->getValueAsListInit(
"dependentTraits");
748 for (
auto *traitInit : *dependentTraits)
749 if (!traitSet.contains(traitInit))
752 trait->getValueAsString(
"trait") +
" requires " +
753 cast<DefInit>(traitInit)->getDef()->getValueAsString(
755 " to precede it in traits list");
758 std::function<void(
const ListInit *)> insert;
759 insert = [&](
const ListInit *traitList) {
760 for (
auto *traitInit : *traitList) {
761 auto *def = cast<DefInit>(traitInit)->getDef();
762 if (def->isSubClassOf(
"TraitList")) {
763 insert(def->getValueAsListInit(
"traits"));
768 if (!traitSet.insert(traitInit).second)
773 if (def->isSubClassOf(
"Interface"))
774 insert(def->getValueAsListInit(
"baseInterfaces"));
778 verifyTraitValidity(def);
785 populateTypeInferenceInfo(argumentsAndResultsIndex);
788 auto *regionsDag = def.getValueAsDag(
"regions");
789 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
790 if (!regionsOp || regionsOp->getDef()->getName() !=
"region") {
791 PrintFatalError(def.getLoc(),
"'regions' must have 'region' directive");
794 for (
unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
795 auto name = regionsDag->getArgNameStr(i);
796 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
798 PrintFatalError(def.getLoc(),
799 Twine(
"undefined kind for region #") + Twine(i));
801 Region region(regionInit->getDef());
802 if (region.isVariadic()) {
805 PrintFatalError(def.getLoc(),
"only the last region can be variadic");
807 PrintFatalError(def.getLoc(),
"variadic regions must be named");
810 regions.push_back({name, region});
814 auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit(
"builders"));
815 if (builderList && !builderList->empty()) {
816 for (
const Init *init : builderList->getElements())
817 builders.emplace_back(cast<DefInit>(init)->
getDef(), def.getLoc());
821 "default builders are skipped and no custom builders provided");
824 LLVM_DEBUG(
print(llvm::dbgs()));
829 return resultTypeMapping[index];
839 return def.getValueAsString(
"description");
845 return def.getValueAsString(
"summary");
849 auto *valueInit = def.getValueInit(
"assemblyFormat");
850 return isa<StringInit>(valueInit);
855 .Case<StringInit>([&](
auto *init) {
return init->getValue(); });
861 if (
auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
862 os <<
"[attribute] " << attr->name <<
'\n';
864 os <<
"[operand] " << cast<NamedTypeConstraint *>(arg)->name <<
'\n';
874 return attrPropOrOperandMapping[index];
878 return "get" + convertToCamelFromSnakeCase(name,
true);
882 return "set" + convertToCamelFromSnakeCase(name,
true);
886 return "remove" + convertToCamelFromSnakeCase(name,
true);
892 return def.getValueAsBit(
"useCustomPropertiesEncoding");
static void assertAccessorInvariants(const Operator &op, StringRef name)
Assert the invariants of accessors generated for the given name.
Attributes are known-constant values of operations.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
StringRef getName() const
This class represents an inferred result type.
static int mapResultIndex(int i)
static int unmapResultIndex(int i)
static bool isResultIndex(int i)
static bool isArgIndex(int i)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
std::string getQualCppClassName() const
Returns this op's C++ class name prefixed with namespaces.
unsigned getNumSuccessors() const
Returns the number of successors.
const NamedRegion & getRegion(unsigned index) const
Returns the index-th region.
TypeConstraint getResultTypeConstraint(int index) const
Returns the index-th result's type constraint.
ArrayRef< SMLoc > getLoc() const
Operator(const llvm::Record &def)
llvm::iterator_range< const_region_iterator > getRegions() const
NamedTypeConstraint & getOperand(int index)
StringRef getCppNamespace() const
Returns this op's C++ namespace.
const_attribute_iterator attribute_begin() const
std::string getGetterName(StringRef name) const
Returns the getter name for the accessor of name.
const_successor_iterator successor_end() const
int getNumOperands() const
StringRef getDescription() const
const_value_range getResults() const
arg_range getArgs() const
const_value_range getOperands() const
const_region_iterator region_begin() const
bool useCustomPropertiesEncoding() const
Whether to generate the readProperty/writeProperty methods for bytecode emission.
StringRef getResultName(int index) const
Returns the index-th result's name.
var_decorator_range getArgDecorators(int index) const
unsigned getNumVariableLengthOperands() const
Returns the number of variadic operands in this operation.
OperandAttrOrProp getArgToOperandAttrOrProp(int index) const
Returns the OperandAttrOrProp corresponding to the index.
var_decorator_range getResultDecorators(int index) const
Returns the index-th result's decorators.
std::string getGenericAdaptorName() const
Returns the name of op's generic adaptor C++ class.
StringRef getExtraClassDefinition() const
Returns this op's extra class definition code.
NamedTypeConstraint & getResult(int index)
Returns the op result at the given index.
const_value_iterator result_begin() const
Op result iterators.
const_attribute_iterator attribute_end() const
const_trait_iterator trait_end() const
std::string getAdaptorName() const
Returns the name of op's adaptor C++ class.
int getNumResults() const
Returns the number of results this op produces.
llvm::iterator_range< const_attribute_iterator > getAttributes() const
const_value_iterator operand_end() const
arg_iterator arg_end() const
int getNumArgs() const
Returns the total number of arguments.
const_value_iterator operand_begin() const
Op operand iterators.
void assertInvariants() const
Check invariants (like no duplicated or conflicted names) and abort the process if any invariant is b...
StringRef getArgName(int index) const
StringRef getDialectName() const
Returns this op's dialect name.
const_region_iterator region_end() const
unsigned getNumVariableLengthResults() const
Returns the number of variable length results in this operation.
bool hasSingleVariadicArg() const
Returns true if the operation has a single variadic arg.
unsigned getNumVariadicSuccessors() const
Returns the number of variadic successors in this operation.
StringRef getSummary() const
bool isVariadic() const
Returns true if this op has variable length operands or results.
llvm::iterator_range< const_trait_iterator > getTraits() const
const Trait * getTrait(llvm::StringRef trait) const
Returns the trait wrapper for the given MLIR C++ trait.
llvm::iterator_range< const_successor_iterator > getSuccessors() const
const_successor_iterator successor_begin() const
void print(llvm::raw_ostream &os) const
Prints the contents in this operator to the given os.
unsigned getNumRegions() const
Returns the number of regions.
const_trait_iterator trait_begin() const
StringRef getExtraClassDeclaration() const
Returns this op's extra class declaration code.
StringRef getAssemblyFormat() const
std::string getSetterName(StringRef name) const
Returns the setter name for the accessor of name.
std::string getOperationName() const
Returns the operation name.
const NamedSuccessor & getSuccessor(unsigned index) const
Returns the index-th successor.
StringRef getCppClassName() const
Returns this op's C++ class name.
bool allResultTypesKnown() const
Return whether all the result types are known.
bool hasAssemblyFormat() const
Query functions for the assembly format of the operator.
unsigned getNumVariadicRegions() const
Returns the number of variadic regions in this operation.
bool skipDefaultBuilders() const
Returns true if default builders should not be generated.
arg_iterator arg_begin() const
Op argument (attribute or operand) iterators.
const InferredResultType & getInferredResultType(int index) const
Return all arguments or type constraints with the same type as result[index].
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
const_value_iterator result_end() const
std::string getRemoverName(StringRef name) const
Returns the remover name for the accessor of name.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
bool hasDescription() const
Query functions for the documentation of the operator.
static Trait create(const llvm::Init *init)
bool isVariableLength() const
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
const char * inferTypeOpInterface
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isVariableLength() const
TypeConstraint constraint
Pair consisting of the kind of argument and the index into operands, attributes, or properties.
static VariableDecorator unwrap(const llvm::Init *init)
A class used to represent the decorators of an operator variable, i.e.