15 #include "llvm/ADT/ScopeExit.h"
21 #define GET_OP_CLASSES
22 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
41 : Extension(state), patterns(patternContainer) {}
48 LogicalResult findAllMatches(StringRef patternName,
Operation *root,
59 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
65 LogicalResult PatternApplicatorExtension::findAllMatches(
68 auto it = compiledPatterns.find(patternName);
69 if (it == compiledPatterns.end()) {
70 auto patternOp = patterns.
lookup<pdl::PatternOp>(patternName);
78 builder.clone(*patternOp);
79 PDLPatternModule patternModule(std::move(pdlModuleOp));
85 for (
const auto &[name, constraintFn] :
88 patternModule.registerConstraintFunction(name, constraintFn);
93 patternModule.registerRewriteFunction(
97 .try_emplace(patternOp.getName(), std::move(patternModule))
106 explicit TrivialPatternRewriter(
MLIRContext *context)
109 TrivialPatternRewriter rewriter(root->
getContext());
113 results.push_back(op);
125 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
127 for (
auto &it : constraintFns)
128 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
131 const llvm::StringMap<PDLConstraintFunction> &
133 return pdlMatchHooks.getConstraintFunctions();
144 auto *extension = state.getExtension<PatternApplicatorExtension>();
146 "expected PatternApplicatorExtension to be attached by the parent op");
148 for (
Operation *root : state.getPayloadOps(getRoot())) {
149 if (failed(extension->findAllMatches(
150 getPatternName().getLeafReference().getValue(), root, targets))) {
152 <<
"could not find pattern '" << getPatternName() <<
"'";
155 results.
set(llvm::cast<OpResult>(getResult()), targets);
159 void transform::PDLMatchOp::getEffects(
174 TransformOpInterface transformOp =
nullptr;
175 for (
Operation &nested : getBody().front()) {
176 if (!isa<pdl::PatternOp>(nested)) {
177 transformOp = cast<TransformOpInterface>(nested);
182 state.addExtension<PatternApplicatorExtension>(getOperation());
183 auto guard = llvm::make_scope_exit(
184 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
186 auto scope = state.make_region_scope(getBody());
187 if (failed(mapBlockArguments(state)))
189 return state.applyTransform(transformOp);
192 void transform::WithPDLPatternsOp::getEffects(
198 Block *body = getBodyBlock();
201 if (isa<pdl::PatternOp>(op))
207 emitOpError() <<
"expects only one non-pattern op in its body";
208 diag.attachNote(topLevelOp->
getLoc()) <<
"first non-pattern op";
209 diag.attachNote(op.getLoc()) <<
"second non-pattern op";
218 <<
"expects only pattern and top-level transform ops in its body";
219 diag.attachNote(op.getLoc()) <<
"offending op";
223 if (
auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
225 diag.attachNote(parent.getLoc()) <<
"parent operation";
231 <<
"expects at least one non-pattern op";
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
OpListType & getOperations()
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...