MLIR 23.0.0git
TransformDialect.cpp
Go to the documentation of this file.
1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Verifier.h"
17#include "llvm/ADT/SCCIterator.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
22#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
23
24#if LLVM_ENABLE_ABI_BREAKING_CHECKS
25void transform::detail::checkImplementsTransformOpInterface(
26 StringRef name, MLIRContext *context) {
27 // Since the operation is being inserted into the Transform dialect and the
28 // dialect does not implement the interface fallback, only check for the op
29 // itself having the interface implementation.
31 *RegisteredOperationName::lookup(name, context);
32 assert((opName.hasInterface<TransformOpInterface>() ||
33 opName.hasInterface<PatternDescriptorOpInterface>() ||
34 opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
35 opName.hasInterface<TypeConverterBuilderOpInterface>() ||
37 opName.hasInterface<NormalFormCheckedOpInterface>()) &&
38 "non-terminator ops injected into the transform dialect must "
39 "implement TransformOpInterface or PatternDescriptorOpInterface or "
40 "ConversionPatternDescriptorOpInterface");
41 if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
42 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
43 !opName.hasInterface<TypeConverterBuilderOpInterface>() &&
44 !opName.hasInterface<NormalFormCheckedOpInterface>()) {
45 assert(opName.hasInterface<MemoryEffectOpInterface>() &&
46 "ops injected into the transform dialect must implement "
47 "MemoryEffectsOpInterface");
48 }
49}
50
51void transform::detail::checkImplementsTransformHandleTypeInterface(
52 TypeID typeID, MLIRContext *context) {
53 const auto &abstractType = AbstractType::lookup(typeID, context);
54 assert((abstractType.hasInterface(
55 TransformHandleTypeInterface::getInterfaceID()) ||
56 abstractType.hasInterface(
57 TransformParamTypeInterface::getInterfaceID()) ||
58 abstractType.hasInterface(
59 TransformValueHandleTypeInterface::getInterfaceID())) &&
60 "expected Transform dialect type to implement one of the three "
61 "interfaces");
62}
63#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
64
65void transform::TransformDialect::initialize() {
66 // Using the checked versions to enable the same assertions as for the ops
67 // from extensions.
68 addOperationsChecked<
69#define GET_OP_LIST
70#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
71 >();
72 initializeAttributes();
73 initializeTypes();
74 initializeLibraryModule();
75}
76
77Attribute transform::TransformDialect::parseAttribute(DialectAsmParser &parser,
78 Type type) const {
79 StringRef keyword;
80 SMLoc loc = parser.getCurrentLocation();
81 if (failed(parser.parseKeyword(&keyword)))
82 return nullptr;
83
84 auto it = attributeParsingHooks.find(keyword);
85 if (it == attributeParsingHooks.end()) {
86 parser.emitError(loc) << "unknown attribute mnemonic: " << keyword;
87 return nullptr;
88 }
89
90 return it->getValue()(parser, type);
91}
92
93void transform::TransformDialect::printAttribute(
94 Attribute attribute, DialectAsmPrinter &printer) const {
95 auto it = attributePrintingHooks.find(attribute.getTypeID());
96 assert(it != attributePrintingHooks.end() && "printing unknown attribute");
97 it->getSecond()(attribute, printer);
98}
99
100Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
101 StringRef keyword;
102 SMLoc loc = parser.getCurrentLocation();
103 if (failed(parser.parseKeyword(&keyword)))
104 return nullptr;
105
106 auto it = typeParsingHooks.find(keyword);
107 if (it == typeParsingHooks.end()) {
108 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
109 return nullptr;
110 }
111
112 return it->getValue()(parser);
113}
114
115void transform::TransformDialect::printType(Type type,
116 DialectAsmPrinter &printer) const {
117 auto it = typePrintingHooks.find(type.getTypeID());
118 assert(it != typePrintingHooks.end() && "printing unknown type");
119 it->getSecond()(type, printer);
120}
121
122LogicalResult transform::TransformDialect::loadIntoLibraryModule(
124 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
125}
126
127void transform::TransformDialect::initializeLibraryModule() {
128 MLIRContext *context = getContext();
129 auto loc =
130 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
131 libraryModule = ModuleOp::create(loc, "__transform_library");
132 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
133 UnitAttr::get(context));
134}
135
136void transform::TransformDialect::reportDuplicateAttributeRegistration(
137 StringRef attrName) {
138 std::string buffer;
139 llvm::raw_string_ostream msg(buffer);
140 msg << "extensible dialect attribute '" << attrName
141 << "' is already registered with a different implementation";
142 llvm::report_fatal_error(StringRef(buffer));
143}
144
145void transform::TransformDialect::reportDuplicateTypeRegistration(
146 StringRef mnemonic) {
147 std::string buffer;
148 llvm::raw_string_ostream msg(buffer);
149 msg << "extensible dialect type '" << mnemonic
150 << "' is already registered with a different implementation";
151 llvm::report_fatal_error(StringRef(buffer));
152}
153
154void transform::TransformDialect::reportDuplicateOpRegistration(
155 StringRef opName) {
156 std::string buffer;
157 llvm::raw_string_ostream msg(buffer);
158 msg << "extensible dialect operation '" << opName
159 << "' is already registered with a mismatching TypeID";
160 llvm::report_fatal_error(StringRef(buffer));
161}
162
163LogicalResult transform::TransformDialect::verifyOperationAttribute(
164 Operation *op, NamedAttribute attribute) {
165 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
166 if (!op->hasTrait<OpTrait::SymbolTable>()) {
167 return emitError(op->getLoc()) << attribute.getName()
168 << " attribute can only be attached to "
169 "operations with symbol tables";
170 }
171
172 // Pre-verify calls and callables because call graph construction below
173 // assumes they are valid, but this verifier runs before verifying the
174 // nested operations.
175 WalkResult walkResult = op->walk([](Operation *nested) {
176 if (!isa<CallableOpInterface, CallOpInterface>(nested))
177 return WalkResult::advance();
178
179 if (failed(verify(nested, /*verifyRecursively=*/false)))
180 return WalkResult::interrupt();
181 return WalkResult::advance();
182 });
183 if (walkResult.wasInterrupted())
184 return failure();
185
186 const mlir::CallGraph callgraph(op);
187 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
188 if (!scc.hasCycle())
189 continue;
190
191 // Need to check this here additionally because this verification may run
192 // before we check the nested operations.
193 if ((*scc->begin())->isExternal())
194 return op->emitOpError() << "contains a call to an external operation, "
195 "which is not allowed";
196
197 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
199 << "recursion not allowed in named sequences";
200 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
201 // Need to check this here additionally because this verification may
202 // run before we check the nested operations.
203 if ((*it)->isExternal()) {
204 return op->emitOpError() << "contains a call to an external "
205 "operation, which is not allowed";
206 }
207
208 Operation *current = (*it)->getCallableRegion()->getParentOp();
209 diag.attachNote(current->getLoc()) << "operation on recursion stack";
210 }
211 return diag;
212 }
213 return success();
214 }
215 if (attribute.getName().getValue() == kTargetTagAttrName) {
216 if (!llvm::isa<StringAttr>(attribute.getValue())) {
217 return op->emitError()
218 << attribute.getName() << " attribute must be a string";
219 }
220 return success();
221 }
222 if (attribute.getName().getValue() == kArgConsumedAttrName ||
223 attribute.getName().getValue() == kArgReadOnlyAttrName) {
224 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
225 return op->emitError()
226 << attribute.getName() << " must be a unit attribute";
227 }
228 return success();
229 }
230 if (attribute.getName().getValue() ==
231 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
232 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
233 return op->emitError()
234 << attribute.getName() << " must be a unit attribute";
235 }
236 return success();
237 }
238 return emitError(op->getLoc())
239 << "unknown attribute: " << attribute.getName();
240}
return success()
b getContext())
static std::string diag(const llvm::Value &value)
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
Definition Attributes.h:52
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480