MLIR 23.0.0git
TransformDialect.cpp
Go to the documentation of this file.
1//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Verifier.h"
17#include "llvm/ADT/SCCIterator.h"
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir;
21
22#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
23
24#define GET_ATTRDEF_CLASSES
25#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26
27#if LLVM_ENABLE_ABI_BREAKING_CHECKS
28void transform::detail::checkImplementsTransformOpInterface(
29 StringRef name, MLIRContext *context) {
30 // Since the operation is being inserted into the Transform dialect and the
31 // dialect does not implement the interface fallback, only check for the op
32 // itself having the interface implementation.
34 *RegisteredOperationName::lookup(name, context);
35 assert((opName.hasInterface<TransformOpInterface>() ||
36 opName.hasInterface<PatternDescriptorOpInterface>() ||
37 opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
38 opName.hasInterface<TypeConverterBuilderOpInterface>() ||
40 "non-terminator ops injected into the transform dialect must "
41 "implement TransformOpInterface or PatternDescriptorOpInterface or "
42 "ConversionPatternDescriptorOpInterface");
43 if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
44 !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
45 !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
46 assert(opName.hasInterface<MemoryEffectOpInterface>() &&
47 "ops injected into the transform dialect must implement "
48 "MemoryEffectsOpInterface");
49 }
50}
51
52void transform::detail::checkImplementsTransformHandleTypeInterface(
53 TypeID typeID, MLIRContext *context) {
54 const auto &abstractType = AbstractType::lookup(typeID, context);
55 assert((abstractType.hasInterface(
56 TransformHandleTypeInterface::getInterfaceID()) ||
57 abstractType.hasInterface(
58 TransformParamTypeInterface::getInterfaceID()) ||
59 abstractType.hasInterface(
60 TransformValueHandleTypeInterface::getInterfaceID())) &&
61 "expected Transform dialect type to implement one of the three "
62 "interfaces");
63}
64#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
65
/// Dialect initialization hook: registers the dialect's own operations, its
/// types (via initializeTypes(), defined outside this view), its attributes,
/// and finally creates the dialect-owned transform library module.
66void transform::TransformDialect::initialize() {
67 // Using the checked versions to enable the same assertions as for the ops
68 // from extensions.
// The op list is spliced in from the TableGen-generated TransformOps.cpp.inc.
69 addOperationsChecked<
70#define GET_OP_LIST
71#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
72 >();
// NOTE(review): initializeTypes() is defined elsewhere; presumably it registers
// the dialect's types and their parsing/printing hooks — confirm.
73 initializeTypes();
// The attribute list is spliced in from the TableGen-generated TransformAttrs.cpp.inc.
74 addAttributes<
75#define GET_ATTRDEF_LIST
76#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
77 >();
// Creates the "__transform_library" module that loadIntoLibraryModule merges into.
78 initializeLibraryModule();
79}
80
81Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
82 StringRef keyword;
83 SMLoc loc = parser.getCurrentLocation();
84 if (failed(parser.parseKeyword(&keyword)))
85 return nullptr;
86
87 auto it = typeParsingHooks.find(keyword);
88 if (it == typeParsingHooks.end()) {
89 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
90 return nullptr;
91 }
92
93 return it->getValue()(parser);
94}
95
96void transform::TransformDialect::printType(Type type,
97 DialectAsmPrinter &printer) const {
98 auto it = typePrintingHooks.find(type.getTypeID());
99 assert(it != typePrintingHooks.end() && "printing unknown type");
100 it->getSecond()(type, printer);
101}
102
/// Merges the symbols of `library` into the dialect-owned library module
/// returned by getLibraryModule(), delegating to detail::mergeSymbolsInto and
/// returning its result. `library` is passed on via std::move.
// NOTE(review): the parameter line of this signature is missing from this
// extract; `library` is presumably an owning module reference — confirm.
// How symbol name conflicts are handled is decided by mergeSymbolsInto, which
// is not visible here.
103LogicalResult transform::TransformDialect::loadIntoLibraryModule(
105 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
106}
107
108void transform::TransformDialect::initializeLibraryModule() {
109 MLIRContext *context = getContext();
110 auto loc =
111 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
112 libraryModule = ModuleOp::create(loc, "__transform_library");
113 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
114 UnitAttr::get(context));
115}
116
117void transform::TransformDialect::reportDuplicateTypeRegistration(
118 StringRef mnemonic) {
119 std::string buffer;
120 llvm::raw_string_ostream msg(buffer);
121 msg << "extensible dialect type '" << mnemonic
122 << "' is already registered with a different implementation";
123 llvm::report_fatal_error(StringRef(buffer));
124}
125
126void transform::TransformDialect::reportDuplicateOpRegistration(
127 StringRef opName) {
128 std::string buffer;
129 llvm::raw_string_ostream msg(buffer);
130 msg << "extensible dialect operation '" << opName
131 << "' is already registered with a mismatching TypeID";
132 llvm::report_fatal_error(StringRef(buffer));
133}
134
135LogicalResult transform::TransformDialect::verifyOperationAttribute(
136 Operation *op, NamedAttribute attribute) {
137 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
138 if (!op->hasTrait<OpTrait::SymbolTable>()) {
139 return emitError(op->getLoc()) << attribute.getName()
140 << " attribute can only be attached to "
141 "operations with symbol tables";
142 }
143
144 // Pre-verify calls and callables because call graph construction below
145 // assumes they are valid, but this verifier runs before verifying the
146 // nested operations.
147 WalkResult walkResult = op->walk([](Operation *nested) {
148 if (!isa<CallableOpInterface, CallOpInterface>(nested))
149 return WalkResult::advance();
150
151 if (failed(verify(nested, /*verifyRecursively=*/false)))
152 return WalkResult::interrupt();
153 return WalkResult::advance();
154 });
155 if (walkResult.wasInterrupted())
156 return failure();
157
158 const mlir::CallGraph callgraph(op);
159 for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
160 if (!scc.hasCycle())
161 continue;
162
163 // Need to check this here additionally because this verification may run
164 // before we check the nested operations.
165 if ((*scc->begin())->isExternal())
166 return op->emitOpError() << "contains a call to an external operation, "
167 "which is not allowed";
168
169 Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
171 << "recursion not allowed in named sequences";
172 for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
173 // Need to check this here additionally because this verification may
174 // run before we check the nested operations.
175 if ((*it)->isExternal()) {
176 return op->emitOpError() << "contains a call to an external "
177 "operation, which is not allowed";
178 }
179
180 Operation *current = (*it)->getCallableRegion()->getParentOp();
181 diag.attachNote(current->getLoc()) << "operation on recursion stack";
182 }
183 return diag;
184 }
185 return success();
186 }
187 if (attribute.getName().getValue() == kTargetTagAttrName) {
188 if (!llvm::isa<StringAttr>(attribute.getValue())) {
189 return op->emitError()
190 << attribute.getName() << " attribute must be a string";
191 }
192 return success();
193 }
194 if (attribute.getName().getValue() == kArgConsumedAttrName ||
195 attribute.getName().getValue() == kArgReadOnlyAttrName) {
196 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
197 return op->emitError()
198 << attribute.getName() << " must be a unit attribute";
199 }
200 return success();
201 }
202 if (attribute.getName().getValue() ==
203 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
204 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
205 return op->emitError()
206 << attribute.getName() << " must be a unit attribute";
207 }
208 return success();
209 }
210 return emitError(op->getLoc())
211 << "unknown attribute: " << attribute.getName();
212}
return success()
b getContext())
static std::string diag(const llvm::Value &value)
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480