21#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
27struct TransformInlinerInterface :
public DialectInlinerInterface {
28 using DialectInlinerInterface::DialectInlinerInterface;
32 bool wouldBeCloned)
const final {
33 return isa<transform::NamedSequenceOp>(callable);
40 IRMapping &valueMapping)
const final {
41 return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp()) &&
42 isa_and_nonnull<transform::NamedSequenceOp>(src->getParentOp());
49 IRMapping &valueMapping)
const final {
50 return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp());
56 void handleTerminator(Operation *op,
ValueRange valuesToRepl)
const final {
57 auto yieldOp = cast<transform::YieldOp>(op);
58 assert(yieldOp.getNumOperands() == valuesToRepl.size() &&
59 "mismatched yield/call result count");
60 for (
auto [from, to] : llvm::zip(valuesToRepl, yieldOp.getOperands()))
61 from.replaceAllUsesWith(to);
66#if LLVM_ENABLE_ABI_BREAKING_CHECKS
67void transform::detail::checkImplementsTransformOpInterface(
76 opName.
hasInterface<ConversionPatternDescriptorOpInterface>() ||
77 opName.
hasInterface<TypeConverterBuilderOpInterface>() ||
80 "non-terminator ops injected into the transform dialect must "
81 "implement TransformOpInterface or PatternDescriptorOpInterface or "
82 "ConversionPatternDescriptorOpInterface");
83 if (!opName.
hasInterface<PatternDescriptorOpInterface>() &&
84 !opName.
hasInterface<ConversionPatternDescriptorOpInterface>() &&
85 !opName.
hasInterface<TypeConverterBuilderOpInterface>() &&
88 "ops injected into the transform dialect must implement "
89 "MemoryEffectsOpInterface");
93void transform::detail::checkImplementsTransformHandleTypeInterface(
96 assert((abstractType.hasInterface(
97 TransformHandleTypeInterface::getInterfaceID()) ||
98 abstractType.hasInterface(
99 TransformParamTypeInterface::getInterfaceID()) ||
100 abstractType.hasInterface(
101 TransformValueHandleTypeInterface::getInterfaceID())) &&
102 "expected Transform dialect type to implement one of the three "
107void transform::TransformDialect::initialize() {
110 addOperationsChecked<
112#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
114 initializeAttributes();
116 initializeLibraryModule();
117 addInterfaces<TransformInlinerInterface>();
127 auto it = attributeParsingHooks.find(keyword);
128 if (it == attributeParsingHooks.end()) {
129 parser.
emitError(loc) <<
"unknown attribute mnemonic: " << keyword;
133 return it->getValue()(parser, type);
136void transform::TransformDialect::printAttribute(
138 auto it = attributePrintingHooks.find(attribute.
getTypeID());
139 assert(it != attributePrintingHooks.end() &&
"printing unknown attribute");
140 it->getSecond()(attribute, printer);
149 auto it = typeParsingHooks.find(keyword);
150 if (it == typeParsingHooks.end()) {
151 parser.
emitError(loc) <<
"unknown type mnemonic: " << keyword;
155 return it->getValue()(parser);
158void transform::TransformDialect::printType(
Type type,
160 auto it = typePrintingHooks.find(type.
getTypeID());
161 assert(it != typePrintingHooks.end() &&
"printing unknown type");
162 it->getSecond()(type, printer);
165LogicalResult transform::TransformDialect::loadIntoLibraryModule(
167 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
170void transform::TransformDialect::initializeLibraryModule() {
174 libraryModule = ModuleOp::create(loc,
"__transform_library");
175 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
176 UnitAttr::get(context));
179void transform::TransformDialect::reportDuplicateAttributeRegistration(
180 StringRef attrName) {
182 llvm::raw_string_ostream msg(buffer);
183 msg <<
"extensible dialect attribute '" << attrName
184 <<
"' is already registered with a different implementation";
185 llvm::report_fatal_error(StringRef(buffer));
188void transform::TransformDialect::reportDuplicateTypeRegistration(
189 StringRef mnemonic) {
191 llvm::raw_string_ostream msg(buffer);
192 msg <<
"extensible dialect type '" << mnemonic
193 <<
"' is already registered with a different implementation";
194 llvm::report_fatal_error(StringRef(buffer));
197void transform::TransformDialect::reportDuplicateOpRegistration(
200 llvm::raw_string_ostream msg(buffer);
201 msg <<
"extensible dialect operation '" << opName
202 <<
"' is already registered with a mismatching TypeID";
203 llvm::report_fatal_error(StringRef(buffer));
206LogicalResult transform::TransformDialect::verifyOperationAttribute(
208 if (attribute.
getName().getValue() == kWithNamedSequenceAttrName) {
211 <<
" attribute can only be attached to "
212 "operations with symbol tables";
219 if (!isa<CallableOpInterface, CallOpInterface>(nested))
229 return detail::verifyNoRecursionInCallGraph(op);
231 if (attribute.
getName().getValue() == kTargetTagAttrName) {
232 if (!llvm::isa<StringAttr>(attribute.
getValue())) {
234 << attribute.
getName() <<
" attribute must be a string";
238 if (attribute.
getName().getValue() == kArgConsumedAttrName ||
239 attribute.
getName().getValue() == kArgReadOnlyAttrName) {
240 if (!llvm::isa<UnitAttr>(attribute.
getValue())) {
242 << attribute.
getName() <<
" must be a unit attribute";
246 if (attribute.
getName().getValue() ==
247 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
248 if (!llvm::isa<UnitAttr>(attribute.
getValue())) {
250 << attribute.
getName() <<
" must be a unit attribute";
255 <<
"unknown attribute: " << attribute.
getName();
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
TypeID getTypeID()
Return a unique identifier for the concrete type.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...