14#include "llvm/ADT/ScopeExit.h"
21#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
40 : Extension(state),
patterns(patternContainer) {}
47 LogicalResult findAllMatches(StringRef patternName,
Operation *root,
58 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
64LogicalResult PatternApplicatorExtension::findAllMatches(
67 auto it = compiledPatterns.find(patternName);
68 if (it == compiledPatterns.end()) {
69 auto patternOp =
patterns.lookup<pdl::PatternOp>(patternName);
77 builder.clone(*patternOp);
78 PDLPatternModule patternModule(std::move(pdlModuleOp));
84 for (
const auto &[name, constraintFn] :
87 patternModule.registerConstraintFunction(name, constraintFn);
92 patternModule.registerRewriteFunction(
96 .try_emplace(patternOp.getName(), std::move(patternModule))
107 results.push_back(op);
119 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
121 for (
auto &it : constraintFns)
122 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
125const llvm::StringMap<PDLConstraintFunction> &
127 return pdlMatchHooks.getConstraintFunctions();
138 auto *extension = state.
getExtension<PatternApplicatorExtension>();
140 "expected PatternApplicatorExtension to be attached by the parent op");
143 if (failed(extension->findAllMatches(
144 getPatternName().getLeafReference().getValue(), root, targets))) {
146 <<
"could not find pattern '" << getPatternName() <<
"'";
149 results.
set(llvm::cast<OpResult>(getResult()), targets);
153void transform::PDLMatchOp::getEffects(
154 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
164DiagnosedSilenceableFailure
165transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
166 transform::TransformResults &results,
167 transform::TransformState &state) {
168 TransformOpInterface transformOp =
nullptr;
169 for (Operation &nested : getBody().front()) {
170 if (!isa<pdl::PatternOp>(nested)) {
171 transformOp = cast<TransformOpInterface>(nested);
176 state.
addExtension<PatternApplicatorExtension>(getOperation());
177 auto guard = llvm::make_scope_exit(
181 if (
failed(mapBlockArguments(state)))
186void transform::WithPDLPatternsOp::getEffects(
187 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
191LogicalResult transform::WithPDLPatternsOp::verify() {
192 Block *body = getBodyBlock();
193 Operation *topLevelOp =
nullptr;
195 if (isa<pdl::PatternOp>(op))
198 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
200 InFlightDiagnostic
diag =
201 emitOpError() <<
"expects only one non-pattern op in its body";
202 diag.attachNote(topLevelOp->
getLoc()) <<
"first non-pattern op";
203 diag.attachNote(op.getLoc()) <<
"second non-pattern op";
210 InFlightDiagnostic
diag =
212 <<
"expects only pattern and top-level transform ops in its body";
213 diag.attachNote(op.getLoc()) <<
"offending op";
217 if (
auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
219 diag.attachNote(parent.getLoc()) <<
"parent operation";
225 <<
"expects at least one non-pattern op";
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
OpListType & getOperations()
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MLIRContext * getContext()
Return the context this operation is associated with.
This class acts as an owning reference to an op, and will automatically destroy the held op on destru...
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns