18 #include "llvm/ADT/EquivalenceClasses.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/Sequence.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/ErrorHandling.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/TableGen/Error.h"
28 #include "llvm/TableGen/Record.h"
31 #define DEBUG_TYPE "mlir-tblgen-operator"
41 : dialect(def.getValueAsDef(
"opDialect")), def(def) {
47 std::tie(prefix, cppClassName) = def.getName().split(
'_');
50 cppClassName = def.getName();
51 }
else if (cppClassName.empty()) {
53 cppClassName = prefix;
56 cppNamespace = def.getValueAsString(
"cppNamespace");
58 populateOpStructure();
63 auto prefix = dialect.
getName();
64 auto opName = def.getValueAsString(
"opName");
66 return std::string(opName);
67 return std::string(llvm::formatv(
"{0}.{1}", prefix, opName));
75 return std::string(llvm::formatv(
"{0}GenericAdaptor",
getCppClassName()));
80 std::string accessorName =
81 convertToCamelFromSnakeCase(name,
true);
91 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
92 if (newName ==
"AttributeNames" || newName ==
"Attributes" ||
93 newName ==
"Operation")
95 if (newName ==
"Operands")
96 return op.
getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
97 if (newName ==
"Regions")
98 return op.
getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
99 if (newName ==
"Type")
103 if (nameOverlapsWithOpAPI(accessorName)) {
107 PrintFatalError(op.
getLoc(),
"generated accessor for `" + name +
108 "` overlaps with a default one; please "
109 "rename to avoid overlap");
116 auto checkName = [&](StringRef name, StringRef entity) {
119 auto insertion = existingNames.insert({name, entity});
120 if (insertion.second) {
125 if (entity == insertion.first->second)
126 PrintFatalError(
getLoc(),
"op has a conflict with two " + entity +
127 " having the same name '" + name +
"'");
128 PrintFatalError(
getLoc(),
"op has a conflict with " +
129 insertion.first->second +
" and " + entity +
130 " both having an entry with the name '" +
156 if (cppNamespace.empty())
157 return std::string(cppClassName);
158 return std::string(llvm::formatv(
"{0}::{1}", cppNamespace, cppClassName));
164 DagInit *results = def.getValueAsDag(
"results");
165 return results->getNumArgs();
169 constexpr
auto attr =
"extraClassDeclaration";
170 if (def.isValueUnset(attr))
172 return def.getValueAsString(attr);
176 constexpr
auto attr =
"extraClassDefinition";
177 if (def.isValueUnset(attr))
179 return def.getValueAsString(attr);
185 return def.getValueAsBit(
"skipDefaultBuilders");
189 return results.begin();
193 return results.end();
201 DagInit *results = def.getValueAsDag(
"results");
206 DagInit *results = def.getValueAsDag(
"results");
207 return results->getArgNameStr(index);
212 cast<DefInit>(def.getValueAsDag(
"results")->getArg(index))->getDef();
213 if (!result->isSubClassOf(
"OpVariable"))
215 return *result->getValueAsListInit(
"decorators");
244 DagInit *argumentValues = def.getValueAsDag(
"arguments");
245 return argumentValues->getArgNameStr(index);
250 cast<DefInit>(def.getValueAsDag(
"arguments")->getArg(index))->getDef();
251 if (!arg->isSubClassOf(
"OpVariable"))
253 return *arg->getValueAsListInit(
"decorators");
257 for (
const auto &t : traits) {
258 if (
const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
259 if (traitDef->getFullyQualifiedTraitName() == trait)
261 }
else if (
const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
262 if (traitDef->getFullyQualifiedTraitName() == trait)
264 }
else if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
265 if (traitDef->getFullyQualifiedTraitName() == trait)
273 return regions.begin();
276 return regions.end();
286 return regions[index];
290 return llvm::count_if(regions,
295 return successors.begin();
298 return successors.end();
308 return successors[index];
312 return llvm::count_if(successors,
317 return traits.begin();
327 return attributes.begin();
330 return attributes.end();
337 return attributes.begin();
340 return attributes.end();
343 return {attribute_begin(), attribute_end()};
347 return operands.begin();
350 return operands.end();
359 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
363 void Operator::populateTypeInferenceInfo(
364 const llvm::StringMap<int> &argumentsAndResultsIndex) {
367 auto &recordKeeper = def.getRecords();
369 allResultsHaveKnownTypes =
false;
384 if (
getTrait(
"::mlir::OpTrait::SameOperandsAndResultType")) {
386 auto *operandI = llvm::find_if(arguments, [](
const Argument &arg) {
390 if (operandI == arguments.end())
394 int operandIdx = operandI - arguments.begin();
396 resultTypeMapping.emplace_back(operandIdx,
"$_self");
398 allResultsHaveKnownTypes =
true;
408 struct ResultTypeInference {
412 bool inferred =
false;
422 if (
getResult(idx).constraint.getBuilderCall()) {
425 infer.inferred =
true;
431 for (
const Trait &trait : traits) {
432 const llvm::Record &def = trait.getDef();
438 if (def.isSubClassOf(
441 if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
442 if (&traitDef->getDef() == inferTrait)
447 if (def.isSubClassOf(
"TypesMatchWith")) {
448 int target = argumentsAndResultsIndex.lookup(def.getValueAsString(
"rhs"));
453 ResultTypeInference &infer = inference[resultIndex];
458 argumentsAndResultsIndex.lookup(def.getValueAsString(
"lhs"));
459 infer.sources.emplace_back(sourceIndex,
460 def.getValueAsString(
"transformer").str());
468 if (!def.isSubClassOf(
"AllTypesMatch"))
471 auto values = def.getValueAsListOfStrings(
"values");
477 std::optional<int> fullyInferredIndex;
479 for (StringRef name : values) {
480 int index = argumentsAndResultsIndex.lookup(name);
485 fullyInferredIndex = index;
487 if (fullyInferredIndex) {
490 for (
int resultIndex : resultIndices) {
491 ResultTypeInference &infer = inference[resultIndex];
492 if (!infer.inferred) {
493 infer.sources.assign(1, {*fullyInferredIndex,
"$_self"});
494 infer.inferred =
true;
499 for (
int resultIndex : resultIndices) {
500 for (
int otherResultIndex : resultIndices) {
501 if (resultIndex == otherResultIndex)
503 inference[resultIndex].sources.emplace_back(otherResultIndex,
511 std::vector<ResultTypeInference *> worklist;
512 for (ResultTypeInference &infer : inference)
514 worklist.push_back(&infer);
518 for (
auto cur = worklist.begin(); cur != worklist.end();) {
519 ResultTypeInference &infer = **cur;
523 assert(InferredResultType::isResultIndex(source.getIndex()));
524 return inference[InferredResultType::unmapResultIndex(
528 if (iter == infer.sources.end()) {
534 infer.inferred =
true;
536 infer.sources.assign(1, *iter);
537 cur = worklist.erase(cur);
541 allResultsHaveKnownTypes = worklist.empty();
544 if (allResultsHaveKnownTypes) {
546 for (
const ResultTypeInference &infer : inference)
547 resultTypeMapping.push_back(infer.sources.front());
551 void Operator::populateOpStructure() {
552 auto &recordKeeper = def.getRecords();
553 auto *typeConstraintClass = recordKeeper.getClass(
"TypeConstraint");
554 auto *attrClass = recordKeeper.getClass(
"Attr");
555 auto *propertyClass = recordKeeper.getClass(
"Property");
556 auto *derivedAttrClass = recordKeeper.getClass(
"DerivedAttr");
557 auto *opVarClass = recordKeeper.getClass(
"OpVariable");
558 numNativeAttributes = 0;
560 DagInit *argumentValues = def.getValueAsDag(
"arguments");
561 unsigned numArgs = argumentValues->getNumArgs();
565 llvm::StringMap<int> argumentsAndResultsIndex;
568 for (
unsigned i = 0; i != numArgs; ++i) {
569 auto *arg = argumentValues->getArg(i);
570 auto givenName = argumentValues->getArgNameStr(i);
571 auto *argDefInit = dyn_cast<DefInit>(arg);
573 PrintFatalError(def.getLoc(),
574 Twine(
"undefined type for argument #") + Twine(i));
575 Record *argDef = argDefInit->getDef();
576 if (argDef->isSubClassOf(opVarClass))
577 argDef = argDef->getValueAsDef(
"constraint");
579 if (argDef->isSubClassOf(typeConstraintClass)) {
582 }
else if (argDef->isSubClassOf(attrClass)) {
583 if (givenName.empty())
584 PrintFatalError(argDef->getLoc(),
"attributes must be named");
585 if (argDef->isSubClassOf(derivedAttrClass))
586 PrintFatalError(argDef->getLoc(),
587 "derived attributes not allowed in argument list");
588 attributes.push_back({givenName,
Attribute(argDef)});
589 ++numNativeAttributes;
590 }
else if (argDef->isSubClassOf(propertyClass)) {
591 if (givenName.empty())
592 PrintFatalError(argDef->getLoc(),
"properties must be named");
593 properties.push_back({givenName,
Property(argDef)});
595 PrintFatalError(def.getLoc(),
596 "unexpected def type; only defs deriving "
597 "from TypeConstraint or Attr or Property are allowed");
599 if (!givenName.empty())
600 argumentsAndResultsIndex[givenName] = i;
604 for (
const auto &val : def.getValues()) {
605 if (
auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
606 if (!record->isSubClassOf(attrClass))
608 if (!record->isSubClassOf(derivedAttrClass))
609 PrintFatalError(def.getLoc(),
610 "unexpected Attr where only DerivedAttr is allowed");
612 if (record->getClasses().size() != 1) {
615 "unsupported attribute modelling, only single class expected");
617 attributes.push_back(
618 {cast<llvm::StringInit>(val.getNameInit())->getValue(),
619 Attribute(cast<DefInit>(val.getValue()))});
627 int operandIndex = 0, attrIndex = 0, propIndex = 0;
628 for (
unsigned i = 0; i != numArgs; ++i) {
629 Record *argDef = dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
630 if (argDef->isSubClassOf(opVarClass))
631 argDef = argDef->getValueAsDef(
"constraint");
633 if (argDef->isSubClassOf(typeConstraintClass)) {
634 attrOrOperandMapping.push_back(
636 arguments.emplace_back(&operands[operandIndex++]);
637 }
else if (argDef->isSubClassOf(attrClass)) {
638 attrOrOperandMapping.push_back(
640 arguments.emplace_back(&attributes[attrIndex++]);
642 assert(argDef->isSubClassOf(propertyClass));
643 arguments.emplace_back(&properties[propIndex++]);
647 auto *resultsDag = def.getValueAsDag(
"results");
648 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
649 if (!outsOp || outsOp->getDef()->getName() !=
"outs") {
650 PrintFatalError(def.getLoc(),
"'results' must have 'outs' directive");
654 for (
unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
655 auto name = resultsDag->getArgNameStr(i);
656 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
658 PrintFatalError(def.getLoc(),
659 Twine(
"undefined type for result #") + Twine(i));
661 auto *resultDef = resultInit->getDef();
662 if (resultDef->isSubClassOf(opVarClass))
663 resultDef = resultDef->getValueAsDef(
"constraint");
669 if (results.back().constraint.isVariadicOfVariadic()) {
672 "'VariadicOfVariadic' results are currently not supported");
677 auto *successorsDag = def.getValueAsDag(
"successors");
678 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
679 if (!successorsOp || successorsOp->getDef()->getName() !=
"successor") {
680 PrintFatalError(def.getLoc(),
681 "'successors' must have 'successor' directive");
684 for (
unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
685 auto name = successorsDag->getArgNameStr(i);
686 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
687 if (!successorInit) {
688 PrintFatalError(def.getLoc(),
689 Twine(
"undefined kind for successor #") + Twine(i));
691 Successor successor(successorInit->getDef());
694 if (i != e - 1 && successor.isVariadic())
695 PrintFatalError(def.getLoc(),
"only the last successor can be variadic");
696 successors.push_back({name, successor});
701 if (
auto *traitList = def.getValueAsListInit(
"traits")) {
704 traits.reserve(traitSet.size());
711 auto verifyTraitValidity = [&](Record *trait) {
712 auto *dependentTraits = trait->getValueAsListInit(
"dependentTraits");
713 for (
auto *traitInit : *dependentTraits)
714 if (!traitSet.contains(traitInit))
717 trait->getValueAsString(
"trait") +
" requires " +
718 cast<DefInit>(traitInit)->getDef()->getValueAsString(
720 " to precede it in traits list");
723 std::function<void(llvm::ListInit *)> insert;
724 insert = [&](llvm::ListInit *traitList) {
725 for (
auto *traitInit : *traitList) {
726 auto *def = cast<DefInit>(traitInit)->getDef();
727 if (def->isSubClassOf(
"TraitList")) {
728 insert(def->getValueAsListInit(
"traits"));
733 if (!traitSet.insert(traitInit).second)
738 if (def->isSubClassOf(
"Interface"))
739 insert(def->getValueAsListInit(
"baseInterfaces"));
743 verifyTraitValidity(def);
750 populateTypeInferenceInfo(argumentsAndResultsIndex);
753 auto *regionsDag = def.getValueAsDag(
"regions");
754 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
755 if (!regionsOp || regionsOp->getDef()->getName() !=
"region") {
756 PrintFatalError(def.getLoc(),
"'regions' must have 'region' directive");
759 for (
unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
760 auto name = regionsDag->getArgNameStr(i);
761 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
763 PrintFatalError(def.getLoc(),
764 Twine(
"undefined kind for region #") + Twine(i));
766 Region region(regionInit->getDef());
767 if (region.isVariadic()) {
770 PrintFatalError(def.getLoc(),
"only the last region can be variadic");
772 PrintFatalError(def.getLoc(),
"variadic regions must be named");
775 regions.push_back({name, region});
780 dyn_cast_or_null<llvm::ListInit>(def.getValueInit(
"builders"));
781 if (builderList && !builderList->empty()) {
782 for (llvm::Init *init : builderList->getValues())
783 builders.emplace_back(cast<llvm::DefInit>(init)->
getDef(), def.getLoc());
787 "default builders are skipped and no custom builders provided");
790 LLVM_DEBUG(
print(llvm::dbgs()));
795 return resultTypeMapping[index];
805 return def.getValueAsString(
"description");
811 return def.getValueAsString(
"summary");
815 auto *valueInit = def.getValueInit(
"assemblyFormat");
816 return isa<llvm::StringInit>(valueInit);
821 .Case<llvm::StringInit>([&](
auto *init) {
return init->getValue(); });
827 if (
auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
828 os <<
"[attribute] " << attr->name <<
'\n';
841 return attrOrOperandMapping[index];
845 return "get" + convertToCamelFromSnakeCase(name,
true);
849 return "set" + convertToCamelFromSnakeCase(name,
true);
853 return "remove" + convertToCamelFromSnakeCase(name,
true);
859 return def.getValueAsBit(
"useCustomPropertiesEncoding");
static void assertAccessorInvariants(const Operator &op, StringRef name)
Assert the invariants of accessors generated for the given name.
Attributes are known-constant values of operations.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
StringRef getName() const
This class represents an inferred result type.
static int mapResultIndex(int i)
static int unmapResultIndex(int i)
static bool isResultIndex(int i)
static bool isArgIndex(int i)
Wrapper class that contains a MLIR op's information (e.g., operands, attributes) defined in TableGen ...
std::string getQualCppClassName() const
Returns this op's C++ class name prefixed with namespaces.
unsigned getNumSuccessors() const
Returns the number of successors.
const NamedRegion & getRegion(unsigned index) const
Returns the index-th region.
TypeConstraint getResultTypeConstraint(int index) const
Returns the index-th result's type constraint.
ArrayRef< SMLoc > getLoc() const
llvm::iterator_range< const_region_iterator > getRegions() const
OperandOrAttribute getArgToOperandOrAttribute(int index) const
Returns the OperandOrAttribute corresponding to the index.
NamedTypeConstraint & getOperand(int index)
StringRef getCppNamespace() const
Returns this op's C++ namespace.
const_attribute_iterator attribute_begin() const
std::string getGetterName(StringRef name) const
Returns the getter name for the accessor of name.
const_successor_iterator successor_end() const
int getNumOperands() const
StringRef getDescription() const
const_value_range getResults() const
arg_range getArgs() const
const_value_range getOperands() const
const_region_iterator region_begin() const
bool useCustomPropertiesEncoding() const
Whether to generate the readProperty/writeProperty methods for bytecode emission.
StringRef getResultName(int index) const
Returns the index-th result's name.
var_decorator_range getArgDecorators(int index) const
unsigned getNumVariableLengthOperands() const
Returns the number of variadic operands in this operation.
var_decorator_range getResultDecorators(int index) const
Returns the index-th result's decorators.
std::string getGenericAdaptorName() const
Returns the name of op's generic adaptor C++ class.
StringRef getExtraClassDefinition() const
Returns this op's extra class definition code.
NamedTypeConstraint & getResult(int index)
Returns the op result at the given index.
const_value_iterator result_begin() const
Op result iterators.
const_attribute_iterator attribute_end() const
const_trait_iterator trait_end() const
std::string getAdaptorName() const
Returns the name of op's adaptor C++ class.
int getNumResults() const
Returns the number of results this op produces.
llvm::iterator_range< const_attribute_iterator > getAttributes() const
const_value_iterator operand_end() const
arg_iterator arg_end() const
int getNumArgs() const
Returns the total number of arguments.
const_value_iterator operand_begin() const
Op operand iterators.
void assertInvariants() const
Check invariants (like no duplicated or conflicted names) and abort the process if any invariant is b...
StringRef getArgName(int index) const
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
StringRef getDialectName() const
Returns this op's dialect name.
const_region_iterator region_end() const
unsigned getNumVariableLengthResults() const
Returns the number of variable length results in this operation.
bool hasSingleVariadicArg() const
Returns true of the operation has a single variadic arg.
unsigned getNumVariadicSuccessors() const
Returns the number of variadic successors in this operation.
StringRef getSummary() const
bool isVariadic() const
Returns true if this op has variable length operands or results.
llvm::iterator_range< const_trait_iterator > getTraits() const
const Trait * getTrait(llvm::StringRef trait) const
Returns the trait wrapper for the given MLIR C++ trait.
llvm::iterator_range< const_successor_iterator > getSuccessors() const
const_successor_iterator successor_begin() const
void print(llvm::raw_ostream &os) const
Prints the contents in this operator to the given os.
unsigned getNumRegions() const
Returns the number of regions.
const_trait_iterator trait_begin() const
StringRef getExtraClassDeclaration() const
Returns this op's extra class declaration code.
StringRef getAssemblyFormat() const
std::string getSetterName(StringRef name) const
Returns the setter name for the accessor of name.
std::string getOperationName() const
Returns the operation name.
const NamedSuccessor & getSuccessor(unsigned index) const
Returns the index-th successor.
StringRef getCppClassName() const
Returns this op's C++ class name.
bool allResultTypesKnown() const
Return whether all the result types are known.
bool hasAssemblyFormat() const
Query functions for the assembly format of the operator.
unsigned getNumVariadicRegions() const
Returns the number of variadic regions in this operation.
Operator(const llvm::Record &def)
bool skipDefaultBuilders() const
Returns true if default builders should not be generated.
arg_iterator arg_begin() const
Op argument (attribute or operand) iterators.
const InferredResultType & getInferredResultType(int index) const
Return all arguments or type constraints with same type as result[index].
const_value_iterator result_end() const
std::string getRemoverName(StringRef name) const
Returns the remove name for the accessor of name.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
bool hasDescription() const
Query functions for the documentation of the operator.
static Trait create(const llvm::Init *init)
bool isVariableLength() const
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
const char * inferTypeOpInterface
Include the generated interface declarations.
bool isVariableLength() const
TypeConstraint constraint
Pair consisting kind of argument and index into operands or attributes.
static VariableDecorator unwrap(llvm::Init *init)
A class used to represent the decorators of an operator variable, i.e.