14 #include "llvm/ADT/ScopeExit.h"
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
40 : Extension(state),
patterns(patternContainer) {}
47 LogicalResult findAllMatches(StringRef patternName,
Operation *root,
58 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
64 LogicalResult PatternApplicatorExtension::findAllMatches(
67 auto it = compiledPatterns.find(patternName);
68 if (it == compiledPatterns.end()) {
69 auto patternOp =
patterns.lookup<pdl::PatternOp>(patternName);
77 builder.clone(*patternOp);
78 PDLPatternModule patternModule(std::move(pdlModuleOp));
84 for (
const auto &[name, constraintFn] :
87 patternModule.registerConstraintFunction(name, constraintFn);
92 patternModule.registerRewriteFunction(
96 .try_emplace(patternOp.getName(), std::move(patternModule))
107 results.push_back(op);
119 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
121 for (
auto &it : constraintFns)
122 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
125 const llvm::StringMap<PDLConstraintFunction> &
127 return pdlMatchHooks.getConstraintFunctions();
138 auto *extension = state.getExtension<PatternApplicatorExtension>();
140 "expected PatternApplicatorExtension to be attached by the parent op");
142 for (
Operation *root : state.getPayloadOps(getRoot())) {
143 if (
failed(extension->findAllMatches(
144 getPatternName().getLeafReference().getValue(), root, targets))) {
146 <<
"could not find pattern '" << getPatternName() <<
"'";
149 results.
set(llvm::cast<OpResult>(getResult()), targets);
153 void transform::PDLMatchOp::getEffects(
168 TransformOpInterface transformOp =
nullptr;
169 for (
Operation &nested : getBody().front()) {
170 if (!isa<pdl::PatternOp>(nested)) {
171 transformOp = cast<TransformOpInterface>(nested);
176 state.addExtension<PatternApplicatorExtension>(getOperation());
177 auto guard = llvm::make_scope_exit(
178 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
180 auto scope = state.make_region_scope(getBody());
181 if (
failed(mapBlockArguments(state)))
183 return state.applyTransform(transformOp);
186 void transform::WithPDLPatternsOp::getEffects(
192 Block *body = getBodyBlock();
195 if (isa<pdl::PatternOp>(op))
201 emitOpError() <<
"expects only one non-pattern op in its body";
202 diag.attachNote(topLevelOp->
getLoc()) <<
"first non-pattern op";
203 diag.attachNote(op.getLoc()) <<
"second non-pattern op";
212 <<
"expects only pattern and top-level transform ops in its body";
213 diag.attachNote(op.getLoc()) <<
"offending op";
217 if (
auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
219 diag.attachNote(parent.getLoc()) <<
"parent operation";
225 <<
"expects at least one non-pattern op";
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
OpListType & getOperations()
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class represents a diagnostic that is inflight and set to be reported.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.