18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/Sequence.h"
20 #include "llvm/ADT/SmallPtrSet.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
29 #define DEBUG_TYPE "mlir-tblgen-operator"
39 using llvm::StringInit;
42 : dialect(def.getValueAsDef(
"opDialect")), def(def) {
48 std::tie(prefix, cppClassName) = def.getName().split(
'_');
51 cppClassName = def.getName();
52 }
else if (cppClassName.empty()) {
54 cppClassName = prefix;
57 cppNamespace = def.getValueAsString(
"cppNamespace");
59 populateOpStructure();
63 std::string Operator::getOperationName()
const {
64 auto prefix = dialect.
getName();
65 auto opName = def.getValueAsString(
"opName");
67 return std::string(opName);
68 return std::string(llvm::formatv(
"{0}.{1}", prefix, opName));
76 return std::string(llvm::formatv(
"{0}GenericAdaptor",
getCppClassName()));
81 std::string accessorName =
82 convertToCamelFromSnakeCase(name,
true);
92 auto nameOverlapsWithOpAPI = [&](StringRef newName) {
93 if (newName ==
"AttributeNames" || newName ==
"Attributes" ||
94 newName ==
"Operation")
96 if (newName ==
"Operands")
98 if (newName ==
"Regions")
100 if (newName ==
"Type")
104 if (nameOverlapsWithOpAPI(accessorName)) {
108 PrintFatalError(op.
getLoc(),
"generated accessor for `" + name +
109 "` overlaps with a default one; please "
110 "rename to avoid overlap");
117 auto checkName = [&](StringRef name, StringRef entity) {
120 auto insertion = existingNames.insert({name, entity});
121 if (insertion.second) {
126 if (entity == insertion.first->second)
127 PrintFatalError(
getLoc(),
"op has a conflict with two " + entity +
128 " having the same name '" + name +
"'");
129 PrintFatalError(
getLoc(),
"op has a conflict with " +
130 insertion.first->second +
" and " + entity +
131 " both having an entry with the name '" +
157 if (cppNamespace.empty())
158 return std::string(cppClassName);
159 return std::string(llvm::formatv(
"{0}::{1}", cppNamespace, cppClassName));
165 const DagInit *results = def.getValueAsDag(
"results");
166 return results->getNumArgs();
170 constexpr
auto attr =
"extraClassDeclaration";
171 if (def.isValueUnset(attr))
173 return def.getValueAsString(attr);
177 constexpr
auto attr =
"extraClassDefinition";
178 if (def.isValueUnset(attr))
180 return def.getValueAsString(attr);
186 return def.getValueAsBit(
"skipDefaultBuilders");
190 return results.begin();
194 return results.end();
202 const DagInit *results = def.getValueAsDag(
"results");
207 const DagInit *results = def.getValueAsDag(
"results");
208 return results->getArgNameStr(index);
212 const Record *result =
213 cast<DefInit>(def.getValueAsDag(
"results")->getArg(index))->getDef();
214 if (!result->isSubClassOf(
"OpVariable"))
216 return *result->getValueAsListInit(
"decorators");
245 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
246 return argumentValues->getArgNameStr(index);
251 cast<DefInit>(def.getValueAsDag(
"arguments")->getArg(index))->getDef();
252 if (!arg->isSubClassOf(
"OpVariable"))
254 return *arg->getValueAsListInit(
"decorators");
258 for (
const auto &t : traits) {
259 if (
const auto *traitDef = dyn_cast<NativeTrait>(&t)) {
260 if (traitDef->getFullyQualifiedTraitName() == trait)
262 }
else if (
const auto *traitDef = dyn_cast<InternalTrait>(&t)) {
263 if (traitDef->getFullyQualifiedTraitName() == trait)
265 }
else if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&t)) {
266 if (traitDef->getFullyQualifiedTraitName() == trait)
274 return regions.begin();
277 return regions.end();
287 return regions[index];
291 return llvm::count_if(regions,
296 return successors.begin();
299 return successors.end();
309 return successors[index];
313 return llvm::count_if(successors,
318 return traits.begin();
328 return attributes.begin();
331 return attributes.end();
338 return attributes.begin();
341 return attributes.end();
344 return {attribute_begin(), attribute_end()};
348 return operands.begin();
351 return operands.end();
360 return any_of(llvm::concat<const NamedTypeConstraint>(operands, results),
364 void Operator::populateTypeInferenceInfo(
365 const llvm::StringMap<int> &argumentsAndResultsIndex) {
368 auto &recordKeeper = def.getRecords();
370 allResultsHaveKnownTypes =
false;
385 if (
getTrait(
"::mlir::OpTrait::SameOperandsAndResultType")) {
387 auto *operandI = llvm::find_if(arguments, [](
const Argument &arg) {
391 if (operandI == arguments.end())
395 int operandIdx = operandI - arguments.begin();
397 resultTypeMapping.emplace_back(operandIdx,
"$_self");
399 allResultsHaveKnownTypes =
true;
409 struct ResultTypeInference {
413 bool inferred =
false;
423 if (
getResult(idx).constraint.getBuilderCall()) {
426 infer.inferred =
true;
432 for (
const Trait &trait : traits) {
433 const Record &def = trait.getDef();
439 if (def.isSubClassOf(
442 if (
const auto *traitDef = dyn_cast<InterfaceTrait>(&trait))
443 if (&traitDef->getDef() == inferTrait)
448 if (def.isSubClassOf(
"TypesMatchWith")) {
449 int target = argumentsAndResultsIndex.lookup(def.getValueAsString(
"rhs"));
454 ResultTypeInference &infer = inference[resultIndex];
459 argumentsAndResultsIndex.lookup(def.getValueAsString(
"lhs"));
460 infer.sources.emplace_back(sourceIndex,
461 def.getValueAsString(
"transformer").str());
472 if (def.isSubClassOf(
"ShapedTypeMatchesElementCountAndTypes")) {
473 StringRef shapedArg = def.getValueAsString(
"shaped");
474 StringRef elementsArg = def.getValueAsString(
"elements");
476 int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
477 int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);
483 ResultTypeInference &infer = inference[resultIndex];
484 if (!infer.inferred) {
485 infer.sources.emplace_back(
487 "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
488 "ShapedType>($_self).getNumElements(), "
489 "::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
490 infer.inferred =
true;
500 if (!def.isSubClassOf(
"AllTypesMatch"))
503 auto values = def.getValueAsListOfStrings(
"values");
509 std::optional<int> fullyInferredIndex;
511 for (StringRef name : values) {
512 int index = argumentsAndResultsIndex.lookup(name);
517 fullyInferredIndex = index;
519 if (fullyInferredIndex) {
522 for (
int resultIndex : resultIndices) {
523 ResultTypeInference &infer = inference[resultIndex];
524 if (!infer.inferred) {
525 infer.sources.assign(1, {*fullyInferredIndex,
"$_self"});
526 infer.inferred =
true;
531 for (
int resultIndex : resultIndices) {
532 for (
int otherResultIndex : resultIndices) {
533 if (resultIndex == otherResultIndex)
535 inference[resultIndex].sources.emplace_back(
543 std::vector<ResultTypeInference *> worklist;
544 for (ResultTypeInference &infer : inference)
546 worklist.push_back(&infer);
550 for (
auto cur = worklist.begin(); cur != worklist.end();) {
551 ResultTypeInference &infer = **cur;
555 assert(InferredResultType::isResultIndex(source.getIndex()));
556 return inference[InferredResultType::unmapResultIndex(
560 if (iter == infer.sources.end()) {
566 infer.inferred =
true;
568 infer.sources.assign(1, *iter);
569 cur = worklist.erase(cur);
573 allResultsHaveKnownTypes = worklist.empty();
576 if (allResultsHaveKnownTypes) {
578 for (
const ResultTypeInference &infer : inference)
579 resultTypeMapping.push_back(infer.sources.front());
583 void Operator::populateOpStructure() {
584 auto &recordKeeper = def.getRecords();
585 auto *typeConstraintClass = recordKeeper.getClass(
"TypeConstraint");
586 auto *attrClass = recordKeeper.getClass(
"Attr");
587 auto *propertyClass = recordKeeper.getClass(
"Property");
588 auto *derivedAttrClass = recordKeeper.getClass(
"DerivedAttr");
589 auto *opVarClass = recordKeeper.getClass(
"OpVariable");
590 numNativeAttributes = 0;
592 const DagInit *argumentValues = def.getValueAsDag(
"arguments");
593 unsigned numArgs = argumentValues->getNumArgs();
597 llvm::StringMap<int> argumentsAndResultsIndex;
600 for (
unsigned i = 0; i != numArgs; ++i) {
601 auto *arg = argumentValues->getArg(i);
602 auto givenName = argumentValues->getArgNameStr(i);
603 auto *argDefInit = dyn_cast<DefInit>(arg);
605 PrintFatalError(def.getLoc(),
606 Twine(
"undefined type for argument #") + Twine(i));
607 const Record *argDef = argDefInit->getDef();
608 if (argDef->isSubClassOf(opVarClass))
609 argDef = argDef->getValueAsDef(
"constraint");
611 if (argDef->isSubClassOf(typeConstraintClass)) {
614 }
else if (argDef->isSubClassOf(attrClass)) {
615 if (givenName.empty())
616 PrintFatalError(argDef->getLoc(),
"attributes must be named");
617 if (argDef->isSubClassOf(derivedAttrClass))
618 PrintFatalError(argDef->getLoc(),
619 "derived attributes not allowed in argument list");
620 attributes.push_back({givenName,
Attribute(argDef)});
621 ++numNativeAttributes;
622 }
else if (argDef->isSubClassOf(propertyClass)) {
623 if (givenName.empty())
624 PrintFatalError(argDef->getLoc(),
"properties must be named");
625 properties.push_back({givenName,
Property(argDef)});
627 PrintFatalError(def.getLoc(),
628 "unexpected def type; only defs deriving "
629 "from TypeConstraint or Attr or Property are allowed");
631 if (!givenName.empty())
632 argumentsAndResultsIndex[givenName] = i;
636 for (
const auto &val : def.getValues()) {
637 if (
auto *record = dyn_cast<llvm::RecordRecTy>(val.getType())) {
638 if (!record->isSubClassOf(attrClass))
640 if (!record->isSubClassOf(derivedAttrClass))
641 PrintFatalError(def.getLoc(),
642 "unexpected Attr where only DerivedAttr is allowed");
644 if (record->getClasses().size() != 1) {
647 "unsupported attribute modelling, only single class expected");
649 attributes.push_back({cast<StringInit>(val.getNameInit())->getValue(),
650 Attribute(cast<DefInit>(val.getValue()))});
658 int operandIndex = 0, attrIndex = 0, propIndex = 0;
659 for (
unsigned i = 0; i != numArgs; ++i) {
660 const Record *argDef =
661 dyn_cast<DefInit>(argumentValues->getArg(i))->getDef();
662 if (argDef->isSubClassOf(opVarClass))
663 argDef = argDef->getValueAsDef(
"constraint");
665 if (argDef->isSubClassOf(typeConstraintClass)) {
666 attrOrOperandMapping.push_back(
668 arguments.emplace_back(&operands[operandIndex++]);
669 }
else if (argDef->isSubClassOf(attrClass)) {
670 attrOrOperandMapping.push_back(
672 arguments.emplace_back(&attributes[attrIndex++]);
674 assert(argDef->isSubClassOf(propertyClass));
675 arguments.emplace_back(&properties[propIndex++]);
679 auto *resultsDag = def.getValueAsDag(
"results");
680 auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
681 if (!outsOp || outsOp->getDef()->getName() !=
"outs") {
682 PrintFatalError(def.getLoc(),
"'results' must have 'outs' directive");
686 for (
unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
687 auto name = resultsDag->getArgNameStr(i);
688 auto *resultInit = dyn_cast<DefInit>(resultsDag->getArg(i));
690 PrintFatalError(def.getLoc(),
691 Twine(
"undefined type for result #") + Twine(i));
693 auto *resultDef = resultInit->getDef();
694 if (resultDef->isSubClassOf(opVarClass))
695 resultDef = resultDef->getValueAsDef(
"constraint");
701 if (results.back().constraint.isVariadicOfVariadic()) {
704 "'VariadicOfVariadic' results are currently not supported");
709 auto *successorsDag = def.getValueAsDag(
"successors");
710 auto *successorsOp = dyn_cast<DefInit>(successorsDag->getOperator());
711 if (!successorsOp || successorsOp->getDef()->getName() !=
"successor") {
712 PrintFatalError(def.getLoc(),
713 "'successors' must have 'successor' directive");
716 for (
unsigned i = 0, e = successorsDag->getNumArgs(); i < e; ++i) {
717 auto name = successorsDag->getArgNameStr(i);
718 auto *successorInit = dyn_cast<DefInit>(successorsDag->getArg(i));
719 if (!successorInit) {
720 PrintFatalError(def.getLoc(),
721 Twine(
"undefined kind for successor #") + Twine(i));
723 Successor successor(successorInit->getDef());
726 if (i != e - 1 && successor.isVariadic())
727 PrintFatalError(def.getLoc(),
"only the last successor can be variadic");
728 successors.push_back({name, successor});
733 if (
auto *traitList = def.getValueAsListInit(
"traits")) {
736 traits.reserve(traitSet.size());
743 auto verifyTraitValidity = [&](
const Record *trait) {
744 auto *dependentTraits = trait->getValueAsListInit(
"dependentTraits");
745 for (
auto *traitInit : *dependentTraits)
746 if (!traitSet.contains(traitInit))
749 trait->getValueAsString(
"trait") +
" requires " +
750 cast<DefInit>(traitInit)->getDef()->getValueAsString(
752 " to precede it in traits list");
755 std::function<void(
const ListInit *)> insert;
756 insert = [&](
const ListInit *traitList) {
757 for (
auto *traitInit : *traitList) {
758 auto *def = cast<DefInit>(traitInit)->getDef();
759 if (def->isSubClassOf(
"TraitList")) {
760 insert(def->getValueAsListInit(
"traits"));
765 if (!traitSet.insert(traitInit).second)
770 if (def->isSubClassOf(
"Interface"))
771 insert(def->getValueAsListInit(
"baseInterfaces"));
775 verifyTraitValidity(def);
782 populateTypeInferenceInfo(argumentsAndResultsIndex);
785 auto *regionsDag = def.getValueAsDag(
"regions");
786 auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
787 if (!regionsOp || regionsOp->getDef()->getName() !=
"region") {
788 PrintFatalError(def.getLoc(),
"'regions' must have 'region' directive");
791 for (
unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
792 auto name = regionsDag->getArgNameStr(i);
793 auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
795 PrintFatalError(def.getLoc(),
796 Twine(
"undefined kind for region #") + Twine(i));
798 Region region(regionInit->getDef());
799 if (region.isVariadic()) {
802 PrintFatalError(def.getLoc(),
"only the last region can be variadic");
804 PrintFatalError(def.getLoc(),
"variadic regions must be named");
807 regions.push_back({name, region});
811 auto *builderList = dyn_cast_or_null<ListInit>(def.getValueInit(
"builders"));
812 if (builderList && !builderList->empty()) {
813 for (
const Init *init : builderList->getElements())
814 builders.emplace_back(cast<DefInit>(init)->
getDef(), def.getLoc());
818 "default builders are skipped and no custom builders provided");
821 LLVM_DEBUG(
print(llvm::dbgs()));
826 return resultTypeMapping[index];
836 return def.getValueAsString(
"description");
842 return def.getValueAsString(
"summary");
846 auto *valueInit = def.getValueInit(
"assemblyFormat");
847 return isa<StringInit>(valueInit);
852 .Case<StringInit>([&](
auto *init) {
return init->getValue(); });
858 if (
auto *attr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
859 os <<
"[attribute] " << attr->name <<
'\n';
861 os <<
"[operand] " << cast<NamedTypeConstraint *>(arg)->name <<
'\n';
872 return attrOrOperandMapping[index];
876 return "get" + convertToCamelFromSnakeCase(name,
true);
880 return "set" + convertToCamelFromSnakeCase(name,
true);
884 return "remove" + convertToCamelFromSnakeCase(name,
true);
890 return def.getValueAsBit(
"useCustomPropertiesEncoding");
static void assertAccessorInvariants(const Operator &op, StringRef name)
Assert the invariants of accessors generated for the given name.
Attributes are known-constant values of operations.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
StringRef getName() const
This class represents an inferred result type.
static int mapResultIndex(int i)
static int unmapResultIndex(int i)
static bool isResultIndex(int i)
static bool isArgIndex(int i)
Wrapper class that contains an MLIR op's information (e.g., operands, attributes) defined in TableGen ...
std::string getQualCppClassName() const
Returns this op's C++ class name prefixed with namespaces.
unsigned getNumSuccessors() const
Returns the number of successors.
const NamedRegion & getRegion(unsigned index) const
Returns the index-th region.
TypeConstraint getResultTypeConstraint(int index) const
Returns the index-th result's type constraint.
ArrayRef< SMLoc > getLoc() const
Operator(const llvm::Record &def)
llvm::iterator_range< const_region_iterator > getRegions() const
OperandOrAttribute getArgToOperandOrAttribute(int index) const
Returns the OperandOrAttribute corresponding to the index.
NamedTypeConstraint & getOperand(int index)
StringRef getCppNamespace() const
Returns this op's C++ namespace.
const_attribute_iterator attribute_begin() const
std::string getGetterName(StringRef name) const
Returns the getter name for the accessor of name.
const_successor_iterator successor_end() const
int getNumOperands() const
StringRef getDescription() const
const_value_range getResults() const
arg_range getArgs() const
const_value_range getOperands() const
const_region_iterator region_begin() const
bool useCustomPropertiesEncoding() const
Whether to generate the readProperty/writeProperty methods for bytecode emission.
StringRef getResultName(int index) const
Returns the index-th result's name.
var_decorator_range getArgDecorators(int index) const
unsigned getNumVariableLengthOperands() const
Returns the number of variadic operands in this operation.
var_decorator_range getResultDecorators(int index) const
Returns the index-th result's decorators.
std::string getGenericAdaptorName() const
Returns the name of op's generic adaptor C++ class.
StringRef getExtraClassDefinition() const
Returns this op's extra class definition code.
NamedTypeConstraint & getResult(int index)
Returns the op result at the given index.
const_value_iterator result_begin() const
Op result iterators.
const_attribute_iterator attribute_end() const
const_trait_iterator trait_end() const
std::string getAdaptorName() const
Returns the name of op's adaptor C++ class.
int getNumResults() const
Returns the number of results this op produces.
llvm::iterator_range< const_attribute_iterator > getAttributes() const
const_value_iterator operand_end() const
arg_iterator arg_end() const
int getNumArgs() const
Returns the total number of arguments.
const_value_iterator operand_begin() const
Op operand iterators.
void assertInvariants() const
Check invariants (like no duplicated or conflicting names) and abort the process if any invariant is broken.
StringRef getArgName(int index) const
StringRef getDialectName() const
Returns this op's dialect name.
const_region_iterator region_end() const
unsigned getNumVariableLengthResults() const
Returns the number of variable length results in this operation.
bool hasSingleVariadicArg() const
Returns true if the operation has a single variadic arg.
unsigned getNumVariadicSuccessors() const
Returns the number of variadic successors in this operation.
StringRef getSummary() const
bool isVariadic() const
Returns true if this op has variable length operands or results.
llvm::iterator_range< const_trait_iterator > getTraits() const
const Trait * getTrait(llvm::StringRef trait) const
Returns the trait wrapper for the given MLIR C++ trait.
llvm::iterator_range< const_successor_iterator > getSuccessors() const
const_successor_iterator successor_begin() const
void print(llvm::raw_ostream &os) const
Prints the contents in this operator to the given os.
unsigned getNumRegions() const
Returns the number of regions.
const_trait_iterator trait_begin() const
StringRef getExtraClassDeclaration() const
Returns this op's extra class declaration code.
StringRef getAssemblyFormat() const
std::string getSetterName(StringRef name) const
Returns the setter name for the accessor of name.
std::string getOperationName() const
Returns the operation name.
const NamedSuccessor & getSuccessor(unsigned index) const
Returns the index-th successor.
StringRef getCppClassName() const
Returns this op's C++ class name.
bool allResultTypesKnown() const
Return whether all the result types are known.
bool hasAssemblyFormat() const
Query functions for the assembly format of the operator.
unsigned getNumVariadicRegions() const
Returns the number of variadic regions in this operation.
bool skipDefaultBuilders() const
Returns true if default builders should not be generated.
arg_iterator arg_begin() const
Op argument (attribute or operand) iterators.
const InferredResultType & getInferredResultType(int index) const
Return all arguments or type constraints with the same type as result[index].
const llvm::Record & getDef() const
Returns the Tablegen definition this operator was constructed from.
const_value_iterator result_end() const
std::string getRemoverName(StringRef name) const
Returns the remover name for the accessor of name.
Argument getArg(int index) const
Op argument (attribute or operand) accessors.
bool hasDescription() const
Query functions for the documentation of the operator.
static Trait create(const llvm::Init *init)
bool isVariableLength() const
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
const char * inferTypeOpInterface
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
bool isVariableLength() const
TypeConstraint constraint
Pair consisting of the kind of argument and an index into operands or attributes.
static VariableDecorator unwrap(const llvm::Init *init)
A class used to represent the decorators of an operator variable, i.e.