MLIR 22.0.0git
PDLExtensionOps.cpp
Go to the documentation of this file.
1//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
14#include "llvm/ADT/ScopeExit.h"
15
16using namespace mlir;
17
19
20#define GET_OP_CLASSES
21#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// PatternApplicatorExtension
25//===----------------------------------------------------------------------===//
26
27namespace {
28/// A TransformState extension that keeps track of compiled PDL pattern sets.
29/// This is intended to be used alongside the WithPDLPatterns op. The extension
30/// can be constructed given an operation that has a SymbolTable trait and
31/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
32/// by one when requested; this behavior is subject to change.
33class PatternApplicatorExtension : public transform::TransformState::Extension {
34public:
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
36
37 /// Creates the extension for patterns contained in `patternContainer`.
 /// Only the container is recorded here; no pattern is compiled until
 /// `findAllMatches` first asks for it.
38 explicit PatternApplicatorExtension(transform::TransformState &state,
39 Operation *patternContainer)
40 : Extension(state), patterns(patternContainer) {}
41
42 /// Appends to `results` the operations contained in `root` that matched the
43 /// PDL pattern with the given name. Note that `root` may or may not be the
44 /// operation that contains PDL patterns. Returns failure (without emitting a
45 /// diagnostic) if the pattern cannot be found. Note that when no operations
46 /// are matched, this still succeeds as long as the pattern exists.
 /// `root` itself is also a match candidate, not only its nested operations.
47 LogicalResult findAllMatches(StringRef patternName, Operation *root,
 // NOTE(review): the trailing `results` parameter of this declaration is not
 // visible in this listing.
49
50private:
51 /// Map from the pattern name to a singleton set of rewrite patterns that only
52 /// contains the pattern with this name. Populated when the pattern is first
53 /// requested.
54 // TODO: reconsider the efficiency of this storage when more usage data is
55 // available. Storing individual patterns in a set and triggering compilation
56 // for each of them has overhead. So does compiling a large set of patterns
57 // only to apply a handful of them.
58 llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
59
60 /// A symbol table operation containing the relevant PDL patterns.
 // NOTE(review): the declaration of the `patterns` member (initialized from
 // `patternContainer` and queried with `lookup<pdl::PatternOp>`) is not
 // visible in this listing.
62};
63
/// Looks up the PDL pattern `patternName` in `patterns`. On the first request
/// it is copied into a fresh module, compiled, and cached in
/// `compiledPatterns`. The cached pattern is then tried on every operation
/// visited by walking `root` (including `root`), and each operation where
/// application succeeds is appended to `results`.
64LogicalResult PatternApplicatorExtension::findAllMatches(
65 StringRef patternName, Operation *root,
 // NOTE(review): the `results` parameter and the opening brace of this
 // definition are not visible in this listing.
67 auto it = compiledPatterns.find(patternName);
 // Cache miss: resolve the pattern by symbol name and compile it once.
68 if (it == compiledPatterns.end()) {
69 auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
 // Unknown name: just signal failure; the caller emits the diagnostic.
70 if (!patternOp)
71 return failure();
72
73 // Copy the pattern operation into a new module that is compiled and
74 // consumed by the PDL interpreter.
75 OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
76 auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
77 builder.clone(*patternOp);
78 PDLPatternModule patternModule(std::move(pdlModuleOp));
79
80 // Merge in the hooks owned by the dialect. Make a copy as they may also
81 // be used by the following operations.
82 auto *dialect =
83 root->getContext()->getLoadedDialect<transform::TransformDialect>();
84 for (const auto &[name, constraintFn] :
85 dialect->getExtraData<transform::PDLMatchHooks>()
87 patternModule.registerConstraintFunction(name, constraintFn);
88 }
89
90 // Register a noop rewriter because PDL requires patterns to end with some
91 // rewrite call.
92 patternModule.registerRewriteFunction(
93 "transform.dialect", [](PatternRewriter &, Operation *) {});
94
 // Cache the compiled singleton set under the pattern's symbol name so later
 // requests skip compilation.
95 it = compiledPatterns
96 .try_emplace(patternOp.getName(), std::move(patternModule))
97 .first;
98 }
99
100 PatternApplicator applicator(it->second);
101 // We want to discourage direct use of PatternRewriter in APIs but in this
102 // very specific case, an IRRewriter is not enough.
103 PatternRewriter rewriter(root->getContext());
104 applicator.applyDefaultCostModel();
 // A successful application counts as a match. The only rewrite function this
 // block registers is the no-op "transform.dialect" one above.
105 root->walk([&](Operation *op) {
106 if (succeeded(applicator.matchAndRewrite(op, rewriter)))
107 results.push_back(op);
108 });
109
110 return success();
111}
112} // namespace
113
114//===----------------------------------------------------------------------===//
115// PDLMatchHooks
116//===----------------------------------------------------------------------===//
117
119 llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
120 // Steal the constraint functions from the given map.
 // Each function is moved out of its entry, so the values left behind in
 // `constraintFns` are moved-from and should not be relied upon afterwards.
121 for (auto &it : constraintFns)
122 pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
123}
124
/// Returns the named constraint functions registered in `pdlMatchHooks`.
/// The map is returned by const reference, so callers borrow it rather than
/// receiving a copy.
125const llvm::StringMap<PDLConstraintFunction> &
127 return pdlMatchHooks.getConstraintFunctions();
128}
129
130//===----------------------------------------------------------------------===//
131// PDLMatchOp
132//===----------------------------------------------------------------------===//
133
135transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
138 auto *extension = state.getExtension<PatternApplicatorExtension>();
139 assert(extension &&
140 "expected PatternApplicatorExtension to be attached by the parent op");
 // `extension` is installed by the enclosing WithPDLPatternsOp while that op
 // applies its body. Its presence is only asserted above, so the check is
 // compiled out when assertions are disabled.
 // Collect matches of the named pattern under every payload op associated
 // with the `root` handle. All of them accumulate into `targets`.
 // NOTE(review): the statement that opens the "could not find pattern"
 // diagnostic below is not visible in this listing. Confirm whether it
 // returns the failure or lets execution continue to the success path.
142 for (Operation *root : state.getPayloadOps(getRoot())) {
143 if (failed(extension->findAllMatches(
144 getPatternName().getLeafReference().getValue(), root, targets))) {
146 << "could not find pattern '" << getPatternName() << "'";
147 }
148 }
 // Publish every collected match through the op's single result handle.
149 results.set(llvm::cast<OpResult>(getResult()), targets);
151}
152
153void transform::PDLMatchOp::getEffects(
154 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
155 onlyReadsHandle(getRootMutable(), effects);
156 producesHandle(getOperation()->getOpResults(), effects);
157 onlyReadsPayload(effects);
158}
159
160//===----------------------------------------------------------------------===//
161// WithPDLPatternsOp
162//===----------------------------------------------------------------------===//
163
/// Applies the single top-level transform op found in the body. The body's PDL
/// patterns are made available to nested ops (e.g. PDLMatchOp) through a
/// PatternApplicatorExtension that exists only for the duration of this call.
164DiagnosedSilenceableFailure
165transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
166 transform::TransformResults &results,
167 transform::TransformState &state) {
 // Select the first non-pattern op in the body. verify() requires exactly one
 // such op and requires it to carry PossibleTopLevelTransformOpTrait.
168 TransformOpInterface transformOp = nullptr;
169 for (Operation &nested : getBody().front()) {
170 if (!isa<pdl::PatternOp>(nested)) {
171 transformOp = cast<TransformOpInterface>(nested);
172 break;
173 }
174 }
175
 // Install the extension that PDLMatchOp::apply looks up. The scope guard
 // removes it again on every exit path from this function.
176 state.addExtension<PatternApplicatorExtension>(getOperation());
177 auto guard = llvm::make_scope_exit(
178 [&]() { state.removeExtension<PatternApplicatorExtension>(); });
179
 // Open a region scope for the body and map its block arguments before the
 // nested transform runs.
 // NOTE(review): the statement guarded by the `if` below is missing from this
 // listing. As shown, the nested transform would only be applied when mapping
 // the block arguments fails.
180 auto scope = state.make_region_scope(getBody());
181 if (failed(mapBlockArguments(state)))
183 return state.applyTransform(transformOp);
184}
185
/// Declares the memory effects of WithPDLPatternsOp into `effects`.
// NOTE(review): the body statement of this function is not visible in this
// listing, so as shown it adds nothing to `effects`. The referenced-symbol list
// for this file includes getPotentialTopLevelEffects, which is presumably
// called here — confirm.
186void transform::WithPDLPatternsOp::getEffects(
187 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
189}
190
/// Verifies WithPDLPatternsOp. Every op in the body must be either a
/// pdl::PatternOp or an op with PossibleTopLevelTransformOpTrait. Exactly one
/// op of the latter kind must be present. The op must not be nested inside
/// another WithPDLPatternsOp.
191LogicalResult transform::WithPDLPatternsOp::verify() {
192 Block *body = getBodyBlock();
 // The single non-pattern (top-level transform) op seen so far, if any.
193 Operation *topLevelOp = nullptr;
194 for (Operation &op : body->getOperations()) {
 // Any number of PDL patterns is allowed.
195 if (isa<pdl::PatternOp>(op))
196 continue;
197
 // Accept one op that may act as a top-level transform. A second one is
 // reported with notes pointing at both.
198 if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
199 if (topLevelOp) {
200 InFlightDiagnostic diag =
201 emitOpError() << "expects only one non-pattern op in its body";
202 diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
203 diag.attachNote(op.getLoc()) << "second non-pattern op";
204 return diag;
205 }
206 topLevelOp = &op;
207 continue;
208 }
209
 // Anything else is neither a pattern nor a top-level transform op.
 // NOTE(review): the line that starts this diagnostic expression is not
 // visible in this listing.
210 InFlightDiagnostic diag =
212 << "expects only pattern and top-level transform ops in its body";
213 diag.attachNote(op.getLoc()) << "offending op";
214 return diag;
215 }
216
 // This op must not appear inside another WithPDLPatternsOp.
217 if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
218 InFlightDiagnostic diag = emitOpError() << "cannot be nested";
219 diag.attachNote(parent.getLoc()) << "parent operation";
220 return diag;
221 }
222
 // A body with only patterns leaves apply() without a transform op to run.
223 if (!topLevelOp) {
224 InFlightDiagnostic diag = emitOpError()
225 << "expects at least one non-pattern op";
226 return diag;
227 }
228
229 return success();
230}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:323
OpListType & getOperations()
Definition Block.h:137
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition Builders.h:246
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class acts as an owning reference to an op, and will automatically destroy the held op on destruction if the held op is valid.
Definition OwningOpRef.h:29
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition SymbolTable.h:24
PDL constraint callbacks that can be used by the PDL extension of the Transform dialect.
void mergeInPDLMatchHooks(llvm::StringMap< PDLConstraintFunction > &&constraintFns)
Takes ownership of the named PDL constraint function from the given map and makes them available for ...
const llvm::StringMap<::mlir::PDLConstraintFunction > & getPDLConstraintHooks() const
Returns the named PDL constraint functions available in the dialect as a map from their name to the function.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
Base class for TransformState extensions that allow TransformState to contain user-specified information in the state object.
The state maintained across applications of various ops implementing the TransformOpInterface.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments provided.
Ty * getExtension()
Returns the extension of the specified type.
auto getPayloadOps(Value value) const
Returns an iterator that enumerates all ops that the given transform IR value corresponds to.
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
void removeExtension()
Removes the extension of the specified type.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operation.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns