MLIR 22.0.0git
TransformDialect.cpp
Go to the documentation of this file.
1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Verifier.h"
17#include "llvm/ADT/SCCIterator.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
22#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
23
24#define GET_ATTRDEF_CLASSES
25#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26
27#ifndef NDEBUG
// Debug-only check for an operation registered into the Transform dialect.
// It asserts that the op:
//   - implements TransformOpInterface, PatternDescriptorOpInterface,
//     ConversionPatternDescriptorOpInterface or
//     TypeConverterBuilderOpInterface (or satisfies the missing disjunct noted
//     below); and
//   - unless it is a pattern / conversion-pattern descriptor or a
//     type-converter builder, also implements MemoryEffectOpInterface.
// NOTE(review): the line naming this function is missing from this copy; the
// documentation index lists it as
// `checkImplementsTransformOpInterface(StringRef name, MLIRContext *context)`.
    StringRef name, MLIRContext *context) {
  // Since the operation is being inserted into the Transform dialect and the
  // dialect does not implement the interface fallback, only check for the op
  // itself having the interface implementation.
  // NOTE(review): the start of the declaration below (presumably
  // `RegisteredOperationName opName =`) is missing from this copy. The
  // optional returned by lookup is dereferenced without a check, so `name`
  // must already be registered in `context`.
      *RegisteredOperationName::lookup(name, context);
  // NOTE(review): one line of this condition (after the
  // TypeConverterBuilderOpInterface disjunct, closing the parenthesis) is
  // missing from this copy. Since the message talks about "non-terminator
  // ops", it presumably exempts terminators; confirm against upstream. Also
  // note the message does not list TypeConverterBuilderOpInterface even though
  // it is accepted.
  assert((opName.hasInterface<TransformOpInterface>() ||
          opName.hasInterface<PatternDescriptorOpInterface>() ||
          opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
          opName.hasInterface<TypeConverterBuilderOpInterface>() ||
         "non-terminator ops injected into the transform dialect must "
         "implement TransformOpInterface or PatternDescriptorOpInterface or "
         "ConversionPatternDescriptorOpInterface");
  // Descriptor and type-converter-builder ops are exempt from the memory
  // effects requirement; every other op must declare its effects.
  if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
      !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
      !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
    assert(opName.hasInterface<MemoryEffectOpInterface>() &&
           "ops injected into the transform dialect must implement "
           "MemoryEffectsOpInterface");
  }
}
51
// Debug-only check for a type registered into the Transform dialect. It
// asserts that the abstract type identified by `typeID` implements at least
// one of TransformHandleTypeInterface, TransformParamTypeInterface or
// TransformValueHandleTypeInterface.
// NOTE(review): the line naming this function is missing from this copy; the
// documentation index lists it as
// `checkImplementsTransformHandleTypeInterface(TypeID typeID,
//                                               MLIRContext *context)`.
    TypeID typeID, MLIRContext *context) {
  // Look up the abstract (uniqued) type description registered in `context`.
  const auto &abstractType = AbstractType::lookup(typeID, context);
  assert((abstractType.hasInterface(
              TransformHandleTypeInterface::getInterfaceID()) ||
          abstractType.hasInterface(
              TransformParamTypeInterface::getInterfaceID()) ||
          abstractType.hasInterface(
              TransformValueHandleTypeInterface::getInterfaceID())) &&
         "expected Transform dialect type to implement one of the three "
         "interfaces");
}
64#endif // NDEBUG
65
/// Dialect initialization hook. In order, it:
///   1. registers the dialect's own ops through the checked path;
///   2. calls initializeTypes();
///   3. registers the TableGen-generated attributes;
///   4. creates the library module (see initializeLibraryModule).
/// The op and attribute lists are produced by the generated .inc files.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  // NOTE(review): initializeTypes() is defined outside this view; presumably
  // it registers the dialect types together with the parsing/printing hooks
  // used by parseType/printType.
  initializeTypes();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
      >();
  initializeLibraryModule();
}
80
81Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
82 StringRef keyword;
83 SMLoc loc = parser.getCurrentLocation();
84 if (failed(parser.parseKeyword(&keyword)))
85 return nullptr;
86
87 auto it = typeParsingHooks.find(keyword);
88 if (it == typeParsingHooks.end()) {
89 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
90 return nullptr;
91 }
92
93 return it->getValue()(parser);
94}
95
96void transform::TransformDialect::printType(Type type,
97 DialectAsmPrinter &printer) const {
98 auto it = typePrintingHooks.find(type.getTypeID());
99 assert(it != typePrintingHooks.end() && "printing unknown type");
100 it->getSecond()(type, printer);
101}
102
/// Merges the symbols of `library` into this dialect's library module
/// (obtained via getLibraryModule()) using detail::mergeSymbolsInto, and
/// returns that call's result. `library` is moved into the call.
/// NOTE(review): the parameter line of this signature is missing from this
/// copy. The std::move suggests it is an owning module reference (presumably
/// `OwningOpRef<ModuleOp> &&library`); confirm against the header.
LogicalResult transform::TransformDialect::loadIntoLibraryModule(
  return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
}
107
108void transform::TransformDialect::initializeLibraryModule() {
109 MLIRContext *context = getContext();
110 auto loc =
111 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
112 libraryModule = ModuleOp::create(loc, "__transform_library");
113 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
114 UnitAttr::get(context));
115}
116
117void transform::TransformDialect::reportDuplicateTypeRegistration(
118 StringRef mnemonic) {
119 std::string buffer;
120 llvm::raw_string_ostream msg(buffer);
121 msg << "extensible dialect type '" << mnemonic
122 << "' is already registered with a different implementation";
123 llvm::report_fatal_error(StringRef(buffer));
124}
125
126void transform::TransformDialect::reportDuplicateOpRegistration(
127 StringRef opName) {
128 std::string buffer;
129 llvm::raw_string_ostream msg(buffer);
130 msg << "extensible dialect operation '" << opName
131 << "' is already registered with a mismatching TypeID";
132 llvm::report_fatal_error(StringRef(buffer));
133}
134
/// Verifies a Transform-dialect discardable attribute attached to `op`.
/// Recognized attribute names:
///   - kWithNamedSequenceAttrName: `op` must have the SymbolTable trait. The
///     call graph nested in `op` must contain no recursion and no cycle
///     through an external node.
///   - kTargetTagAttrName: the value must be a StringAttr.
///   - kArgConsumedAttrName / kArgReadOnlyAttrName: the value must be a
///     UnitAttr.
///   - FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName:
///     the value must be a UnitAttr.
/// Any other attribute name is rejected with "unknown attribute".
LogicalResult transform::TransformDialect::verifyOperationAttribute(
    Operation *op, NamedAttribute attribute) {
  if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
    if (!op->hasTrait<OpTrait::SymbolTable>()) {
      return emitError(op->getLoc()) << attribute.getName()
                                     << " attribute can only be attached to "
                                        "operations with symbol tables";
    }

    // Pre-verify calls and callables because call graph construction below
    // assumes they are valid, but this verifier runs before verifying the
    // nested operations.
    WalkResult walkResult = op->walk([](Operation *nested) {
      if (!isa<CallableOpInterface, CallOpInterface>(nested))
        return WalkResult::advance();

      // verifyRecursively=false: only `nested` itself is verified here, not
      // the operations inside its regions.
      if (failed(verify(nested, /*verifyRecursively=*/false)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    if (walkResult.wasInterrupted())
      return failure();

    // Walk the strongly connected components of the call graph rooted at
    // `op`. An SCC with a cycle means some callable is (mutually) recursive.
    const mlir::CallGraph callgraph(op);
    for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
      if (!scc.hasCycle())
        continue;

      // Need to check this here additionally because this verification may run
      // before we check the nested operations.
      if ((*scc->begin())->isExternal())
        return op->emitOpError() << "contains a call to an external operation, "
                                    "which is not allowed";

      // The error is anchored at the first op of the cycle, with a note for
      // each other member. Only the first cyclic SCC encountered is reported,
      // because the function returns at the end of this iteration.
      Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
      // NOTE(review): the line declaring `diag` (presumably
      // `InFlightDiagnostic diag = emitError(first->getLoc())`) is missing
      // from this copy right before the continuation below; restore it from
      // upstream.
          << "recursion not allowed in named sequences";
      for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
        // Need to check this here additionally because this verification may
        // run before we check the nested operations.
        if ((*it)->isExternal()) {
          return op->emitOpError() << "contains a call to an external "
                                      "operation, which is not allowed";
        }

        Operation *current = (*it)->getCallableRegion()->getParentOp();
        diag.attachNote(current->getLoc()) << "operation on recursion stack";
      }
      return diag;
    }
    return success();
  }
  if (attribute.getName().getValue() == kTargetTagAttrName) {
    if (!llvm::isa<StringAttr>(attribute.getValue())) {
      return op->emitError()
             << attribute.getName() << " attribute must be a string";
    }
    return success();
  }
  if (attribute.getName().getValue() == kArgConsumedAttrName ||
      attribute.getName().getValue() == kArgReadOnlyAttrName) {
    if (!llvm::isa<UnitAttr>(attribute.getValue())) {
      return op->emitError()
             << attribute.getName() << " must be a unit attribute";
    }
    return success();
  }
  if (attribute.getName().getValue() ==
      FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
    if (!llvm::isa<UnitAttr>(attribute.getValue())) {
      return op->emitError()
             << attribute.getName() << " must be a unit attribute";
    }
    return success();
  }
  return emitError(op->getLoc())
         << "unknown attribute: " << attribute.getName();
}
return success()
b getContext())
static std::string diag(const llvm::Value &value)
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void checkImplementsTransformHandleTypeInterface(TypeID typeID, MLIRContext *context)
Asserts that the type provided as template argument implements the TransformHandleTypeInterface, the TransformParamTypeInterface, or the TransformValueHandleTypeInterface.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context)
Asserts that the operations provided as template arguments implement the TransformOpInterface and MemoryEffectsOpInterface.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423