17#include "llvm/ADT/SCCIterator.h"
18#include "llvm/ADT/TypeSwitch.h"
22#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
24#if LLVM_ENABLE_ABI_BREAKING_CHECKS
// Debug sanity check, appearing inside the `#if LLVM_ENABLE_ABI_BREAKING_CHECKS`
// region opened above, for an operation being injected into the Transform
// dialect. From the visible lines it:
//  - asserts that `opName` implements one of a disjunction of interfaces
//    (ConversionPatternDescriptorOpInterface, TypeConverterBuilderOpInterface,
//    and further alternatives on lines missing from this listing);
//  - for ops implementing none of PatternDescriptorOpInterface,
//    ConversionPatternDescriptorOpInterface and TypeConverterBuilderOpInterface
//    (possibly with further conditions not shown), applies an additional
//    check whose message says the op must implement MemoryEffectsOpInterface.
// NOTE(review): the first assertion's message names TransformOpInterface,
// PatternDescriptorOpInterface and ConversionPatternDescriptorOpInterface but
// not TypeConverterBuilderOpInterface, which the condition also accepts —
// consider extending the message so it matches the condition.
// NOTE(review): several original lines of this function (e.g. 26-33, 36-37,
// 44-45) are missing from this listing, so the full conditions are not visible.
25void transform::detail::checkImplementsTransformOpInterface(
34 opName.
hasInterface<ConversionPatternDescriptorOpInterface>() ||
35 opName.
hasInterface<TypeConverterBuilderOpInterface>() ||
38 "non-terminator ops injected into the transform dialect must "
39 "implement TransformOpInterface or PatternDescriptorOpInterface or "
40 "ConversionPatternDescriptorOpInterface");
// Ops that are not pattern/conversion/type-converter descriptors are subject
// to the extra memory-effects requirement below.
41 if (!opName.
hasInterface<PatternDescriptorOpInterface>() &&
42 !opName.
hasInterface<ConversionPatternDescriptorOpInterface>() &&
43 !opName.
hasInterface<TypeConverterBuilderOpInterface>() &&
46 "ops injected into the transform dialect must implement "
47 "MemoryEffectsOpInterface");
// Debug sanity check for a type injected into the Transform dialect. It
// asserts that `abstractType` implements at least one of
// TransformHandleTypeInterface, TransformParamTypeInterface or
// TransformValueHandleTypeInterface, each queried by interface ID.
// NOTE(review): `abstractType` is obtained on a line not shown in this
// listing, and the assertion message is truncated here.
51void transform::detail::checkImplementsTransformHandleTypeInterface(
54 assert((abstractType.hasInterface(
55 TransformHandleTypeInterface::getInterfaceID()) ||
56 abstractType.hasInterface(
57 TransformParamTypeInterface::getInterfaceID()) ||
58 abstractType.hasInterface(
59 TransformValueHandleTypeInterface::getInterfaceID())) &&
60 "expected Transform dialect type to implement one of the three "
// Dialect initialization hook. It pulls in the generated TransformOps.cpp.inc
// (presumably inside an operation-registration list whose opening line is not
// shown), then calls initializeAttributes() and initializeLibraryModule().
65void transform::TransformDialect::initialize() {
70#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
72 initializeAttributes();
74 initializeLibraryModule();
// Fragment of the attribute-parsing entry point (its signature is not shown).
// It looks up `keyword` in attributeParsingHooks:
//  - on a miss, it emits "unknown attribute mnemonic: <keyword>" at `loc`;
//  - otherwise, it returns the result of the hook called with (parser, type).
// NOTE(review): the miss-path return value and the declarations of `keyword`
// and `loc` are on lines missing from this listing.
84 auto it = attributeParsingHooks.find(keyword);
85 if (it == attributeParsingHooks.end()) {
86 parser.
emitError(loc) <<
"unknown attribute mnemonic: " << keyword;
90 return it->getValue()(parser, type);
// Prints `attribute`. It looks up the printing hook registered for the
// attribute's TypeID in attributePrintingHooks and invokes it with
// (attribute, printer).
// NOTE(review): a missing hook is caught only by the assert. With assertions
// disabled, `it` would be the end iterator and `it->getSecond()` would still
// be dereferenced (undefined behavior).
93void transform::TransformDialect::printAttribute(
95 auto it = attributePrintingHooks.find(attribute.
getTypeID());
96 assert(it != attributePrintingHooks.end() &&
"printing unknown attribute");
97 it->getSecond()(attribute, printer);
// Fragment of the type-parsing entry point (its signature is not shown).
// It looks up `keyword` in typeParsingHooks:
//  - on a miss, it emits "unknown type mnemonic: <keyword>" at `loc`;
//  - otherwise, it returns the result of the hook called with `parser`.
// NOTE(review): the miss-path return value and the declarations of `keyword`
// and `loc` are on lines missing from this listing.
106 auto it = typeParsingHooks.find(keyword);
107 if (it == typeParsingHooks.end()) {
108 parser.
emitError(loc) <<
"unknown type mnemonic: " << keyword;
112 return it->getValue()(parser);
// Prints `type`. It looks up the printing hook registered for the type's
// TypeID in typePrintingHooks and invokes it with (type, printer).
// NOTE(review): as in printAttribute, only the assert guards against a
// missing hook. With assertions disabled, the end iterator would be
// dereferenced.
115void transform::TransformDialect::printType(
Type type,
117 auto it = typePrintingHooks.find(type.
getTypeID());
118 assert(it != typePrintingHooks.end() &&
"printing unknown type");
119 it->getSecond()(type, printer);
// Delegates to detail::mergeSymbolsInto with getLibraryModule() and `library`
// (moved from), presumably merging the symbols of `library` into the
// dialect's library module, and returns its LogicalResult.
// NOTE(review): the parameter declaration is on a line not shown.
122LogicalResult transform::TransformDialect::loadIntoLibraryModule(
124 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
// Creates the dialect's library module. It stores in libraryModule a
// ModuleOp created at `loc` with the name "__transform_library", then sets
// the kWithNamedSequenceAttrName unit attribute on it.
// NOTE(review): `loc` and `context` are defined on lines not shown.
127void transform::TransformDialect::initializeLibraryModule() {
131 libraryModule = ModuleOp::create(loc,
"__transform_library");
132 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
133 UnitAttr::get(context));
// Aborts via llvm::report_fatal_error. The message states that the extensible
// dialect attribute `attrName` is already registered with a different
// implementation. It is formatted into `buffer` through a raw_string_ostream.
// NOTE(review): `buffer` is declared on a line not shown.
136void transform::TransformDialect::reportDuplicateAttributeRegistration(
137 StringRef attrName) {
139 llvm::raw_string_ostream msg(buffer);
140 msg <<
"extensible dialect attribute '" << attrName
141 <<
"' is already registered with a different implementation";
142 llvm::report_fatal_error(StringRef(buffer));
// Aborts via llvm::report_fatal_error. The message states that the extensible
// dialect type `mnemonic` is already registered with a different
// implementation. It is formatted into `buffer` through a raw_string_ostream.
// NOTE(review): `buffer` is declared on a line not shown.
145void transform::TransformDialect::reportDuplicateTypeRegistration(
146 StringRef mnemonic) {
148 llvm::raw_string_ostream msg(buffer);
149 msg <<
"extensible dialect type '" << mnemonic
150 <<
"' is already registered with a different implementation";
151 llvm::report_fatal_error(StringRef(buffer));
// Aborts via llvm::report_fatal_error. The message states that the extensible
// dialect operation `opName` is already registered with a mismatching TypeID.
// NOTE(review): the parameter list and the declaration of `buffer` are on
// lines not shown.
154void transform::TransformDialect::reportDuplicateOpRegistration(
157 llvm::raw_string_ostream msg(buffer);
158 msg <<
"extensible dialect operation '" << opName
159 <<
"' is already registered with a mismatching TypeID";
160 llvm::report_fatal_error(StringRef(buffer));
// Verifies a Transform-dialect attribute (`attribute`) attached to `op` and
// returns a LogicalResult. The visible cases, dispatched on the attribute
// name, are:
//  - kWithNamedSequenceAttrName: the error text requires the op to have a
//    symbol table. Nested ops implementing CallableOpInterface or
//    CallOpInterface are considered. The code iterates strongly connected
//    components of `callgraph` (built on lines not shown) and emits op errors
//    for calls to external operations and for recursion in named sequences,
//    attaching "operation on recursion stack" notes.
//  - kTargetTagAttrName: the value must be a StringAttr.
//  - kArgConsumedAttrName / kArgReadOnlyAttrName: the value must be a UnitAttr.
//  - FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName: the
//    value must be a UnitAttr.
//  - otherwise: an "unknown attribute: <name>" error.
// NOTE(review): many original lines (signature parameters, the walk, the call
// graph construction, error-emission heads, returns) are missing from this
// listing. Treat this summary as covering only the visible statements.
163LogicalResult transform::TransformDialect::verifyOperationAttribute(
165 if (attribute.
getName().getValue() == kWithNamedSequenceAttrName) {
168 <<
" attribute can only be attached to "
169 "operations with symbol tables";
176 if (!isa<CallableOpInterface, CallOpInterface>(nested))
187 for (
auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
193 if ((*scc->begin())->isExternal())
194 return op->
emitOpError() <<
"contains a call to an external operation, "
195 "which is not allowed";
199 <<
"recursion not allowed in named sequences";
200 for (
auto it = std::next(scc->begin()); it != scc->end(); ++it) {
203 if ((*it)->isExternal()) {
204 return op->
emitOpError() <<
"contains a call to an external "
205 "operation, which is not allowed";
209 diag.attachNote(current->
getLoc()) <<
"operation on recursion stack";
// Target tag: the value must be a string attribute.
215 if (attribute.
getName().getValue() == kTargetTagAttrName) {
216 if (!llvm::isa<StringAttr>(attribute.
getValue())) {
218 << attribute.
getName() <<
" attribute must be a string";
// Argument consumed / read-only markers: the value must be a unit attribute.
222 if (attribute.
getName().getValue() == kArgConsumedAttrName ||
223 attribute.
getName().getValue() == kArgReadOnlyAttrName) {
224 if (!llvm::isa<UnitAttr>(attribute.
getValue())) {
226 << attribute.
getName() <<
" must be a unit attribute";
// Silence-tracking-failures marker: the value must be a unit attribute.
230 if (attribute.
getName().getValue() ==
231 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
232 if (!llvm::isa<UnitAttr>(attribute.
getValue())) {
234 << attribute.
getName() <<
" must be a unit attribute";
239 <<
"unknown attribute: " << attribute.
getName();
static std::string diag(const llvm::Value &value)
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
TypeID getTypeID()
Return a unique identifier for the concrete type.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...