15 #include "llvm/ADT/ScopeExit.h"
21 #define GET_OP_CLASSES
22 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
41 : Extension(state),
patterns(patternContainer) {}
48 LogicalResult findAllMatches(StringRef patternName,
Operation *root,
59 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
65 LogicalResult PatternApplicatorExtension::findAllMatches(
68 auto it = compiledPatterns.find(patternName);
69 if (it == compiledPatterns.end()) {
70 auto patternOp =
patterns.lookup<pdl::PatternOp>(patternName);
78 builder.clone(*patternOp);
79 PDLPatternModule patternModule(std::move(pdlModuleOp));
85 for (
const auto &[name, constraintFn] :
88 patternModule.registerConstraintFunction(name, constraintFn);
93 patternModule.registerRewriteFunction(
97 .try_emplace(patternOp.getName(), std::move(patternModule))
106 explicit TrivialPatternRewriter(
MLIRContext *context)
109 TrivialPatternRewriter rewriter(root->
getContext());
113 results.push_back(op);
125 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
127 for (
auto &it : constraintFns)
128 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
131 const llvm::StringMap<PDLConstraintFunction> &
133 return pdlMatchHooks.getConstraintFunctions();
144 auto *extension = state.getExtension<PatternApplicatorExtension>();
146 "expected PatternApplicatorExtension to be attached by the parent op");
148 for (
Operation *root : state.getPayloadOps(getRoot())) {
149 if (failed(extension->findAllMatches(
150 getPatternName().getLeafReference().getValue(), root, targets))) {
152 <<
"could not find pattern '" << getPatternName() <<
"'";
155 results.
set(llvm::cast<OpResult>(getResult()), targets);
159 void transform::PDLMatchOp::getEffects(
174 TransformOpInterface transformOp =
nullptr;
175 for (
Operation &nested : getBody().front()) {
176 if (!isa<pdl::PatternOp>(nested)) {
177 transformOp = cast<TransformOpInterface>(nested);
182 state.addExtension<PatternApplicatorExtension>(getOperation());
183 auto guard = llvm::make_scope_exit(
184 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
186 auto scope = state.make_region_scope(getBody());
187 if (failed(mapBlockArguments(state)))
189 return state.applyTransform(transformOp);
192 void transform::WithPDLPatternsOp::getEffects(
198 Block *body = getBodyBlock();
201 if (isa<pdl::PatternOp>(op))
207 emitOpError() <<
"expects only one non-pattern op in its body";
208 diag.attachNote(topLevelOp->
getLoc()) <<
"first non-pattern op";
209 diag.attachNote(op.getLoc()) <<
"second non-pattern op";
218 <<
"expects only pattern and top-level transform ops in its body";
219 diag.attachNote(op.getLoc()) <<
"offending op";
223 if (
auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
225 diag.attachNote(parent.getLoc()) <<
"parent operation";
231 <<
"expects at least one non-pattern op";
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
OpListType & getOperations()
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...