MLIR  22.0.0git
TransformDialect.cpp
Go to the documentation of this file.
1 //===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Verifier.h"
17 #include "llvm/ADT/SCCIterator.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir;
21 
22 #include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
23 
24 #define GET_ATTRDEF_CLASSES
25 #include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26 
27 #ifndef NDEBUG
// Debug-only sanity checks (compiled only when NDEBUG is not defined).
//
// First check: asserts that the operation registered under `name` in
// `context` may be injected into the Transform dialect.
// NOTE(review): the line holding this function's name (original line 28) is
// missing from this listing; the cross-reference index suggests it is
// checkImplementsTransformOpInterface -- confirm.
29  StringRef name, MLIRContext *context) {
30  // Since the operation is being inserted into the Transform dialect and the
31  // dialect does not implement the interface fallback, only check for the op
32  // itself having the interface implementation.
34  *RegisteredOperationName::lookup(name, context);
// NOTE(review): lookup() returns a std::optional, which is dereferenced
// unchecked above -- this assumes `name` is already registered in `context`.
// Accept the op if it implements any transform-related interface below or is
// a terminator.
// NOTE(review): the assertion message does not mention
// TypeConverterBuilderOpInterface, although the condition accepts it.
35  assert((opName.hasInterface<TransformOpInterface>() ||
36  opName.hasInterface<PatternDescriptorOpInterface>() ||
37  opName.hasInterface<ConversionPatternDescriptorOpInterface>() ||
38  opName.hasInterface<TypeConverterBuilderOpInterface>() ||
39  opName.hasTrait<OpTrait::IsTerminator>()) &&
40  "non-terminator ops injected into the transform dialect must "
41  "implement TransformOpInterface or PatternDescriptorOpInterface or "
42  "ConversionPatternDescriptorOpInterface");
// Ops that implement none of the pattern-descriptor / type-converter-builder
// interfaces must additionally implement MemoryEffectOpInterface.
// NOTE(review): the message spells it "MemoryEffectsOpInterface", but the
// interface actually checked is MemoryEffectOpInterface.
43  if (!opName.hasInterface<PatternDescriptorOpInterface>() &&
44  !opName.hasInterface<ConversionPatternDescriptorOpInterface>() &&
45  !opName.hasInterface<TypeConverterBuilderOpInterface>()) {
46  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
47  "ops injected into the transform dialect must implement "
48  "MemoryEffectsOpInterface");
49  }
50 }
51 
// Second check: asserts that the type identified by `typeID` implements one of
// TransformHandleTypeInterface, TransformParamTypeInterface or
// TransformValueHandleTypeInterface.
// NOTE(review): the signature line (original line 52) is missing from this
// listing; presumably checkImplementsTransformHandleTypeInterface -- confirm.
53  TypeID typeID, MLIRContext *context) {
54  const auto &abstractType = AbstractType::lookup(typeID, context);
55  assert((abstractType.hasInterface(
56  TransformHandleTypeInterface::getInterfaceID()) ||
57  abstractType.hasInterface(
58  TransformParamTypeInterface::getInterfaceID()) ||
59  abstractType.hasInterface(
60  TransformValueHandleTypeInterface::getInterfaceID())) &&
61  "expected Transform dialect type to implement one of the three "
62  "interfaces");
63 }
64 #endif // NDEBUG
65 
// Dialect initialization. This function:
//  1. registers the ops listed in the generated TransformOps.cpp.inc through
//     the checked registration path;
//  2. registers the dialect types;
//  3. registers the attributes listed in the generated TransformAttrs.cpp.inc;
//  4. creates the library module.
66 void transform::TransformDialect::initialize() {
67  // Using the checked versions to enable the same assertions as for the ops
68  // from extensions.
69  addOperationsChecked<
70 #define GET_OP_LIST
71 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
72  >();
// Type registration is implemented outside this listing.
73  initializeTypes();
74  addAttributes<
75 #define GET_ATTRDEF_LIST
76 #include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
77  >();
// Create the module that loadIntoLibraryModule later merges symbols into.
78  initializeLibraryModule();
79 }
80 
// Body of the Transform dialect's type parser. It reads a keyword mnemonic and
// dispatches to the parsing hook registered for it in `typeParsingHooks`.
// Returns nullptr if the keyword cannot be parsed. If no hook is registered,
// it emits "unknown type mnemonic" and returns nullptr.
// NOTE(review): the signature line (original line 81) is missing from this
// listing.
82  StringRef keyword;
// Capture the location before consuming the keyword so that the
// unknown-mnemonic error points at the mnemonic itself.
83  SMLoc loc = parser.getCurrentLocation();
84  if (failed(parser.parseKeyword(&keyword)))
85  return nullptr;
86 
87  auto it = typeParsingHooks.find(keyword);
88  if (it == typeParsingHooks.end()) {
89  parser.emitError(loc) << "unknown type mnemonic: " << keyword;
90  return nullptr;
91  }
92 
// The registered hook parses the rest of the type.
93  return it->getValue()(parser);
94 }
95 
97  DialectAsmPrinter &printer) const {
// Prints `type` by dispatching to the printing hook registered for its TypeID
// in `typePrintingHooks`.
// NOTE(review): the signature's first line (original line 96) is missing from
// this listing. The lookup below is guarded only by an assert, so in NDEBUG
// builds an unregistered type would dereference the end iterator.
98  auto it = typePrintingHooks.find(type.getTypeID());
99  assert(it != typePrintingHooks.end() && "printing unknown type");
100  it->getSecond()(type, printer);
101 }
102 
// Merges all symbols from `library` into the dialect's library module using
// detail::mergeSymbolsInto. The diagnostic returned by that helper becomes this
// function's LogicalResult.
// NOTE(review): the parameter line (original line 104) is missing from this
// listing. Since `library` is std::move'd into mergeSymbolsInto, it presumably
// transfers ownership of an OwningOpRef -- confirm against the header.
103 LogicalResult transform::TransformDialect::loadIntoLibraryModule(
105  return detail::mergeSymbolsInto(getLibraryModule(), std::move(library));
106 }
107 
108 void transform::TransformDialect::initializeLibraryModule() {
109  MLIRContext *context = getContext();
110  auto loc =
111  FileLineColLoc::get(context, "<transform-dialect-library-module>", 0, 0);
112  libraryModule = ModuleOp::create(loc, "__transform_library");
113  libraryModule.get()->setAttr(TransformDialect::kWithNamedSequenceAttrName,
114  UnitAttr::get(context));
115 }
116 
117 void transform::TransformDialect::reportDuplicateTypeRegistration(
118  StringRef mnemonic) {
119  std::string buffer;
120  llvm::raw_string_ostream msg(buffer);
121  msg << "extensible dialect type '" << mnemonic
122  << "' is already registered with a different implementation";
123  llvm::report_fatal_error(StringRef(buffer));
124 }
125 
126 void transform::TransformDialect::reportDuplicateOpRegistration(
127  StringRef opName) {
128  std::string buffer;
129  llvm::raw_string_ostream msg(buffer);
130  msg << "extensible dialect operation '" << opName
131  << "' is already registered with a mismatching TypeID";
132  llvm::report_fatal_error(StringRef(buffer));
133 }
134 
// Verifies an attribute attached to `op`. Recognized attribute names are:
//  - kWithNamedSequenceAttrName: `op` must have a symbol table, and the call
//    graph of its nested callables must contain no recursion and no cycles
//    through external callables.
//  - kTargetTagAttrName: the value must be a StringAttr.
//  - kArgConsumedAttrName / kArgReadOnlyAttrName: the value must be a UnitAttr.
//  - FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName: the
//    value must be a UnitAttr.
// Any other attribute name is reported as "unknown attribute".
135 LogicalResult transform::TransformDialect::verifyOperationAttribute(
136  Operation *op, NamedAttribute attribute) {
137  if (attribute.getName().getValue() == kWithNamedSequenceAttrName) {
138  if (!op->hasTrait<OpTrait::SymbolTable>()) {
139  return emitError(op->getLoc()) << attribute.getName()
140  << " attribute can only be attached to "
141  "operations with symbol tables";
142  }
143 
144  // Pre-verify calls and callables because call graph construction below
145  // assumes they are valid, but this verifier runs before verifying the
146  // nested operations.
147  WalkResult walkResult = op->walk([](Operation *nested) {
148  if (!isa<CallableOpInterface, CallOpInterface>(nested))
149  return WalkResult::advance();
150 
151  if (failed(verify(nested, /*verifyRecursively=*/false)))
152  return WalkResult::interrupt();
153  return WalkResult::advance();
154  });
155  if (walkResult.wasInterrupted())
156  return failure();
157 
// Walk the strongly connected components of the call graph rooted at `op`.
// Any component with a cycle means a callable can (transitively) call itself.
158  const mlir::CallGraph callgraph(op);
159  for (auto scc = llvm::scc_begin(&callgraph); !scc.isAtEnd(); ++scc) {
160  if (!scc.hasCycle())
161  continue;
162 
163  // Need to check this here additionally because this verification may run
164  // before we check the nested operations.
165  if ((*scc->begin())->isExternal())
166  return op->emitOpError() << "contains a call to an external operation, "
167  "which is not allowed";
168 
// Report the recursion at the first callable of the cycle, then attach a note
// for every other callable in it.
// NOTE(review): the line declaring `diag` (original line 170) is missing from
// this listing.
169  Operation *first = (*scc->begin())->getCallableRegion()->getParentOp();
171  << "recursion not allowed in named sequences";
172  for (auto it = std::next(scc->begin()); it != scc->end(); ++it) {
173  // Need to check this here additionally because this verification may
174  // run before we check the nested operations.
175  if ((*it)->isExternal()) {
176  return op->emitOpError() << "contains a call to an external "
177  "operation, which is not allowed";
178  }
179 
180  Operation *current = (*it)->getCallableRegion()->getParentOp();
181  diag.attachNote(current->getLoc()) << "operation on recursion stack";
182  }
// The in-flight diagnostic is converted to this function's LogicalResult
// return value.
183  return diag;
184  }
185  return success();
186  }
187  if (attribute.getName().getValue() == kTargetTagAttrName) {
188  if (!llvm::isa<StringAttr>(attribute.getValue())) {
189  return op->emitError()
190  << attribute.getName() << " attribute must be a string";
191  }
192  return success();
193  }
// Argument-effect markers must be unit attributes.
194  if (attribute.getName().getValue() == kArgConsumedAttrName ||
195  attribute.getName().getValue() == kArgReadOnlyAttrName) {
196  if (!llvm::isa<UnitAttr>(attribute.getValue())) {
197  return op->emitError()
198  << attribute.getName() << " must be a unit attribute";
199  }
200  return success();
201  }
202  if (attribute.getName().getValue() ==
203  FindPayloadReplacementOpInterface::kSilenceTrackingFailuresAttrName) {
204  if (!llvm::isa<UnitAttr>(attribute.getValue())) {
205  return op->emitError()
206  << attribute.getName() << " must be a unit attribute";
207  }
208  return success();
209  }
210  return emitError(op->getLoc())
211  << "unknown attribute: " << attribute.getName();
212 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static const AbstractType & lookup(TypeID typeID, MLIRContext *context)
Look up the specified abstract type in the MLIRContext and return a reference to it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
The DialectAsmParser has methods for interacting with the asm parser when parsing attributes and types.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom printType() method.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::IsTerminator>()`.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<OpTrait::SymbolTable>()`.
Definition: Operation.h:749
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block, or region, depending on the callback provided.
Definition: Operation.h:797
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction.
Definition: OwningOpRef.h:29
This is a "type erased" representation of a registered operation.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
TypeID getTypeID()
Return a unique identifier for the concrete type.
Definition: Types.h:101
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
InFlightDiagnostic mergeSymbolsInto(Operation *target, OwningOpRef< Operation * > other)
Merge all symbols from other into target.
Definition: Utils.cpp:80
void checkImplementsTransformHandleTypeInterface(TypeID typeID, MLIRContext *context)
Asserts that the type identified by `typeID` implements one of TransformHandleTypeInterface, TransformParamTypeInterface, or TransformValueHandleTypeInterface.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context)
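Asserts that the operation registered as `name` implements TransformOpInterface, PatternDescriptorOpInterface, ConversionPatternDescriptorOpInterface, or TypeConverterBuilderOpInterface (or is a terminator), and that operations implementing none of the pattern-descriptor or type-converter-builder interfaces also implement MemoryEffectOpInterface.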
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and, if `verifyRecursively` is true, on its nested operations.
Definition: Verifier.cpp:423