MLIR 23.0.0git
TransformDialect.cpp
Go to the documentation of this file.
//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Verifier.h"
18
19using namespace mlir;
20
21#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
22
23namespace {
24/// This interface enables inlining of `transform.named_sequence` operations
25/// into the body of other `transform.named_sequence` operations. The dialect
26/// does not allow inlining into any other context.
27struct TransformInlinerInterface : public DialectInlinerInterface {
28 using DialectInlinerInterface::DialectInlinerInterface;
29
30 /// A call may be inlined when its callee is a `transform.named_sequence`.
31 bool isLegalToInline(Operation *call, Operation *callable,
32 bool wouldBeCloned) const final {
33 return isa<transform::NamedSequenceOp>(callable);
34 }
35
36 /// A region may be inlined into another region only when both are bodies of
37 /// `transform.named_sequence` operations: this restricts inlining to the
38 /// "named sequence into named sequence" case.
39 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
40 IRMapping &valueMapping) const final {
41 return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp()) &&
42 isa_and_nonnull<transform::NamedSequenceOp>(src->getParentOp());
43 }
44
45 /// Any operation is legal to inline into the body of a
46 /// `transform.named_sequence`. Whether a particular operation is actually
47 /// valid in that context is enforced by the regular op verifiers.
48 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
49 IRMapping &valueMapping) const final {
50 return isa_and_nonnull<transform::NamedSequenceOp>(dest->getParentOp());
51 }
52
53 /// Replace the `transform.yield` terminator of an inlined single-block
54 /// region by directly forwarding its operands to the values that used to be
55 /// produced by the call site.
56 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
57 auto yieldOp = cast<transform::YieldOp>(op);
58 assert(yieldOp.getNumOperands() == valuesToRepl.size() &&
59 "mismatched yield/call result count");
60 for (auto [from, to] : llvm::zip(valuesToRepl, yieldOp.getOperands()))
61 from.replaceAllUsesWith(to);
62 }
63};
64} // namespace
65
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// Asserts that the operation registered in `context` under `name` is an op
/// the Transform dialect accepts. It must implement one of the transform op
/// interfaces listed below or be a terminator. Unless it is one of the
/// descriptor-style ops, it must also implement MemoryEffectOpInterface.
void transform::detail::checkImplementsTransformOpInterface(
    StringRef name, MLIRContext *context) {
  // Since the operation is being inserted into the Transform dialect and the
  // dialect does not implement the interface fallback, only check for the op
  // itself having the interface implementation.
  auto maybeOpName = RegisteredOperationName::lookup(name, context);
  // Dereferencing an empty lookup result would be undefined behavior; fail
  // with a readable message instead.
  assert(maybeOpName && "operation must be registered before being checked");
  RegisteredOperationName opName = *maybeOpName;

  assert((opName.hasInterface<TransformOpInterface>() ||
          opName.hasInterface<PatternDescriptorOpInterface>() ||
          opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
          opName.hasInterface<TypeConverterBuilderOpInterface>() ||
          opName.hasTrait<OpTrait::IsTerminator>() ||
          opName.hasInterface<NormalFormCheckedOpInterface>()) &&
         "non-terminator ops injected into the transform dialect must "
         "implement TransformOpInterface, PatternDescriptorOpInterface, "
         "ConversionPatternDescriptorOpInterface, "
         "TypeConverterBuilderOpInterface or NormalFormCheckedOpInterface");

  // Descriptor-style ops are exempt from the memory-effects requirement.
  bool isDescriptorOp =
      opName.hasInterface<PatternDescriptorOpInterface>() ||
      opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
      opName.hasInterface<TypeConverterBuilderOpInterface>() ||
      opName.hasInterface<NormalFormCheckedOpInterface>();
  if (!isDescriptorOp) {
    assert(opName.hasInterface<MemoryEffectOpInterface>() &&
           "ops injected into the transform dialect must implement "
           "MemoryEffectOpInterface");
  }
}

/// Asserts that the type identified by `typeID` implements one of the
/// Transform handle, parameter, or value-handle type interfaces.
void transform::detail::checkImplementsTransformHandleTypeInterface(
    TypeID typeID, MLIRContext *context) {
  const auto &abstractType = AbstractType::lookup(typeID, context);
  assert((abstractType.hasInterface(
              TransformHandleTypeInterface::getInterfaceID()) ||
          abstractType.hasInterface(
              TransformParamTypeInterface::getInterfaceID()) ||
          abstractType.hasInterface(
              TransformValueHandleTypeInterface::getInterfaceID())) &&
         "expected Transform dialect type to implement one of the three "
         "interfaces");
}
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
106
107void transform::TransformDialect::initialize() {
108 // Using the checked versions to enable the same assertions as for the ops
109 // from extensions.
110 addOperationsChecked<
111#define GET_OP_LIST
112#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
113 >();
114 initializeAttributes();
115 initializeTypes();
116 initializeLibraryModule();
117 addInterfaces<TransformInlinerInterface>();
118}
119
120Attribute transform::TransformDialect::parseAttribute(DialectAsmParser &parser,
121 Type type) const {
122 StringRef keyword;
123 SMLoc loc = parser.getCurrentLocation();
124 if (failed(parser.parseKeyword(&keyword)))
125 return nullptr;
126
127 auto it = attributeParsingHooks.find(keyword);
128 if (it == attributeParsingHooks.end()) {
129 parser.emitError(loc) << "unknown attribute mnemonic: " << keyword;
130 return nullptr;
131 }
132
133 return it->getValue()(parser, type);
134}
135
136void transform::TransformDialect::printAttribute(
137 Attribute attribute, DialectAsmPrinter &printer) const {
138 auto it = attributePrintingHooks.find(attribute.getTypeID());
139 assert(it != attributePrintingHooks.end() && "printing unknown attribute");
140 it->getSecond()(attribute, printer);
141}
142
143Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
144 StringRef keyword;
145 SMLoc loc = parser.getCurrentLocation();
146 if (failed(parser.parseKeyword(&keyword)))
147 return nullptr;
148
149 auto it = typeParsingHooks.find(keyword);
150 if (it == typeParsingHooks.end()) {
151 parser.emitError(loc) << "unknown type mnemonic: " << keyword;
152 return nullptr;
153 }
154
155 return it->getValue()(parser);
156}
157
158void transform::TransformDialect::printType(Type type,
159 DialectAsmPrinter &printer) const {
160 auto it = typePrintingHooks.find(type.getTypeID());
161 assert(it != typePrintingHooks.end() && "printing unknown type");
162 it->getSecond()(type, printer);
163}
164
165LogicalResult transform::TransformDialect::loadIntoLibraryModule(
167 return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
168}
169
170void transform::TransformDialect::initializeLibraryModule() {
171 MLIRContext *context = getContext();
172 auto loc =
173 FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
174 libraryModule = ModuleOp::create(loc, "__transform_library");
175 libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
176 UnitAttr::get(context));
177}
178
179void transform::TransformDialect::reportDuplicateAttributeRegistration(
180 StringRef attrName) {
181 std::string buffer;
182 llvm::raw_string_ostream msg(buffer);
183 msg << "extensible dialect attribute '" << attrName
184 << "' is already registered with a different implementation";
185 llvm::report_fatal_error(StringRef(buffer));
186}
187
188void transform::TransformDialect::reportDuplicateTypeRegistration(
189 StringRef mnemonic) {
190 std::string buffer;
191 llvm::raw_string_ostream msg(buffer);
192 msg << "extensible dialect type '" << mnemonic
193 << "' is already registered with a different implementation";
194 llvm::report_fatal_error(StringRef(buffer));
195}
196
197void transform::TransformDialect::reportDuplicateOpRegistration(
198 StringRef opName) {
199 std::string buffer;
200 llvm::raw_string_ostream msg(buffer);
201 msg << "extensible dialect operation '" << opName
202 << "' is already registered with a mismatching TypeID";
203 llvm::report_fatal_error(StringRef(buffer));
204}
205
206LogicalResult transform::TransformDialect::verifyOperationAttribute(
207 Operation *op, NamedAttribute attribute) {
208 if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
209 if (!op->hasTrait<OpTrait::SymbolTable>()) {
210 return emitError(op->getLoc()) << attribute.getName()
211 << " attribute can only be attached to "
212 "operations with symbol tables";
213 }
214
215 // Pre-verify calls and callables because call graph construction below
216 // assumes they are valid, but this verifier runs before verifying the
217 // nested operations.
218 WalkResult walkResult = op->walk([](Operation *nested) {
219 if (!isa<CallableOpInterface, CallOpInterface>(nested))
220 return WalkResult::advance();
221
222 if (failed(verify(nested, /*verifyRecursively=*/false)))
223 return WalkResult::interrupt();
224 return WalkResult::advance();
225 });
226 if (walkResult.wasInterrupted())
227 return failure();
228
229 return detail::verifyNoRecursionInCallGraph(op);
230 }
231 if (attribute.getName().getValue() == kTargetTagAttrName) {
232 if (!llvm::isa<StringAttr>(attribute.getValue())) {
233 return op->emitError()
234 << attribute.getName() << " attribute must be a string";
235 }
236 return success();
237 }
238 if (attribute.getName().getValue() == kArgConsumedAttrName ||
239 attribute.getName().getValue() == kArgReadOnlyAttrName) {
240 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
241 return op->emitError()
242 << attribute.getName() << " must be a unit attribute";
243 }
244 return success();
245 }
246 if (attribute.getName().getValue() ==
247 FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
248 if (!llvm::isa<UnitAttr>(attribute.getValue())) {
249 return op->emitError()
250 << attribute.getName() << " must be a unit attribute";
251 }
252 return success();
253 }
254 return emitError(op->getLoc())
255 << "unknown attribute: " << attribute.getName();
256}
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b getContext())
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
TypeID getTypeID()
Return a unique identifier for the concrete attribute type.
Definition Attributes.h:52
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and type...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class provides the API for ops that are known to be terminators.
A trait used to provide symbol table functionalities to a region operation.
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
Definition OwningOpRef.h:29
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition Types.h:101
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:480