MLIR  22.0.0git
PDLExtensionOps.cpp
Go to the documentation of this file.
1 //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
14 #include "llvm/ADT/ScopeExit.h"
15 
16 using namespace mlir;
17 
19 
20 #define GET_OP_CLASSES
21 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // PatternApplicatorExtension
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 /// A TransformState extension that keeps track of compiled PDL pattern sets.
29 /// This is intended to be used along with the WithPDLPatternsOp. The extension
30 /// can be constructed given an operation that has a SymbolTable trait and
31 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
32 /// by one when requested; this behavior is subject to change.
/// The extension is attached to the TransformState only while a
/// WithPDLPatternsOp is being applied; PDLMatchOp retrieves it by type.
33 class PatternApplicatorExtension : public transform::TransformState::Extension {
34 public:
/// Provides the TypeID for this class; presumably this is what TransformState
/// keys extensions on in get/add/removeExtension<PatternApplicatorExtension>()
/// — confirm.
35  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
36 
37  /// Creates the extension for patterns contained in `patternContainer`, which
/// is expected to be a symbol table op holding pdl::PatternOp instances; it is
/// used to initialize the `patterns` member.
38  explicit PatternApplicatorExtension(transform::TransformState &state,
39  Operation *patternContainer)
40  : Extension(state), patterns(patternContainer) {}
41 
42  /// Appends to `results` the operations contained in `root` that matched the
43  /// PDL pattern with the given name. Note that `root` may or may not be the
44  /// operation that contains PDL patterns. Returns failure, without emitting a
45  /// diagnostic, if the pattern cannot be found; the caller reports the error.
46  /// When no operations are matched, this still succeeds as long as the
/// pattern exists. `root` itself is also a match candidate.
47  LogicalResult findAllMatches(StringRef patternName, Operation *root,
49 
50 private:
51  /// Map from the pattern name to a singleton set of rewrite patterns that only
52  /// contains the pattern with this name. Populated when the pattern is first
53  /// requested.
54  // TODO: reconsider the efficiency of this storage when more usage data is
55  // available. Storing individual patterns in a set and triggering compilation
56  // for each of them has overhead. So does compiling a large set of patterns
57  // only to apply a handful of them.
58  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
59 
60  /// A symbol table operation containing the relevant PDL patterns.
62 };
63 
// Compiles (on first request) and caches the PDL pattern named `patternName`,
// then walks `root` and appends every operation the pattern matches to
// `results`.
64 LogicalResult PatternApplicatorExtension::findAllMatches(
65  StringRef patternName, Operation *root,
67  auto it = compiledPatterns.find(patternName);
68  if (it == compiledPatterns.end()) {
69  auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
// Unknown pattern: return failure without a diagnostic; the caller reports it.
70  if (!patternOp)
71  return failure();
72 
73  // Copy the pattern operation into a new module that is compiled and
74  // consumed by the PDL interpreter.
75  OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
76  auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
77  builder.clone(*patternOp);
78  PDLPatternModule patternModule(std::move(pdlModuleOp));
79 
80  // Merge in the hooks owned by the dialect. Copy them rather than moving
81  // them, since they may also be used by subsequent operations.
82  auto *dialect =
83  root->getContext()->getLoadedDialect<transform::TransformDialect>();
84  for (const auto &[name, constraintFn] :
85  dialect->getExtraData<transform::PDLMatchHooks>()
87  patternModule.registerConstraintFunction(name, constraintFn);
88  }
89 
90  // Register a noop rewriter because PDL requires patterns to end with some
91  // rewrite call.
92  patternModule.registerRewriteFunction(
93  "transform.dialect", [](PatternRewriter &, Operation *) {});
94 
// Cache the compiled single-pattern set under the pattern's symbol name so
// later requests for the same pattern skip recompilation.
95  it = compiledPatterns
96  .try_emplace(patternOp.getName(), std::move(patternModule))
97  .first;
98  }
99 
100  PatternApplicator applicator(it->second);
101  // We want to discourage direct use of PatternRewriter in APIs, but in this
102  // very specific case, an IRRewriter is not enough.
103  PatternRewriter rewriter(root->getContext());
104  applicator.applyDefaultCostModel();
// Record every op visited by the walk (including `root`) whose match succeeds.
// NOTE(review): the IR is presumably left untouched because the only rewrite
// function registered above is the no-op "transform.dialect" one — confirm
// that patterns used with this extension never perform other rewrites.
105  root->walk([&](Operation *op) {
106  if (succeeded(applicator.matchAndRewrite(op, rewriter)))
107  results.push_back(op);
108  });
109 
110  return success();
111 }
112 } // namespace
113 
114 //===----------------------------------------------------------------------===//
115 // PDLMatchHooks
116 //===----------------------------------------------------------------------===//
117 
119  llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
120  // Steal the constraint functions from the given map. The entries left in
// `constraintFns` are moved-from afterwards and should not be relied upon.
121  for (auto &it : constraintFns)
122  pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
123 }
124 
/// Returns a read-only reference to the named constraint functions collected
/// in `pdlMatchHooks` (e.g. via mergeInPDLMatchHooks).
125 const llvm::StringMap<PDLConstraintFunction> &
127  return pdlMatchHooks.getConstraintFunctions();
128 }
129 
130 //===----------------------------------------------------------------------===//
131 // PDLMatchOp
132 //===----------------------------------------------------------------------===//
133 
135 transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
137  transform::TransformState &state) {
// The extension is expected to be attached by the enclosing WithPDLPatternsOp
// for the duration of its apply(); the assert below checks this in debug
// builds.
138  auto *extension = state.getExtension<PatternApplicatorExtension>();
139  assert(extension &&
140  "expected PatternApplicatorExtension to be attached by the parent op");
// Matches from every payload root are accumulated into a single list.
// NOTE(review): `targets` is not deduplicated, so if payload roots are nested
// within each other an op can be appended more than once — confirm intended.
141  SmallVector<Operation *> targets;
142  for (Operation *root : state.getPayloadOps(getRoot())) {
// Only the leaf of the pattern's symbol reference is used for the lookup; an
// unresolvable pattern name is reported as an error.
143  if (failed(extension->findAllMatches(
144  getPatternName().getLeafReference().getValue(), root, targets))) {
146  << "could not find pattern '" << getPatternName() << "'";
147  }
148  }
149  results.set(llvm::cast<OpResult>(getResult()), targets);
151 }
152 
/// Declares this op's effects: it reads the `root` handle, produces the result
/// handle, and only reads (never modifies) the payload IR.
153 void transform::PDLMatchOp::getEffects(
155  onlyReadsHandle(getRootMutable(), effects);
156  producesHandle(getOperation()->getOpResults(), effects);
157  onlyReadsPayload(effects);
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // WithPDLPatternsOp
162 //===----------------------------------------------------------------------===//
163 
165 transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
167  transform::TransformState &state) {
// Find the first non-pattern op in the body: the transform to execute.
// verify() ensures there is exactly one; cast<> asserts it implements
// TransformOpInterface.
168  TransformOpInterface transformOp = nullptr;
169  for (Operation &nested : getBody().front()) {
170  if (!isa<pdl::PatternOp>(nested)) {
171  transformOp = cast<TransformOpInterface>(nested);
172  break;
173  }
174  }
175 
// Make this op's PDL patterns available to nested PDLMatchOps while the body
// runs; the scope guard removes the extension again on every exit path.
176  state.addExtension<PatternApplicatorExtension>(getOperation());
177  auto guard = llvm::make_scope_exit(
178  [&]() { state.removeExtension<PatternApplicatorExtension>(); });
179 
// Open a region scope for the body, map its block arguments, then apply the
// single transform op found above.
180  auto scope = state.make_region_scope(getBody());
181  if (failed(mapBlockArguments(state)))
183  return state.applyTransform(transformOp);
184 }
185 
/// Reports the memory effects of this op; presumably those implied by
/// PossibleTopLevelTransformOpTrait via getPotentialTopLevelEffects — confirm.
186 void transform::WithPDLPatternsOp::getEffects(
189 }
190 
/// Checks that the body contains only pdl.pattern ops plus exactly one other
/// op, which must be a top-level transform op, and that this op is not nested
/// inside another WithPDLPatternsOp.
191 LogicalResult transform::WithPDLPatternsOp::verify() {
192  Block *body = getBodyBlock();
193  Operation *topLevelOp = nullptr;
// Skip pattern ops; any other op must be the single top-level transform.
194  for (Operation &op : body->getOperations()) {
195  if (isa<pdl::PatternOp>(op))
196  continue;
197 
199  if (topLevelOp) {
201  emitOpError() << "expects only one non-pattern op in its body";
202  diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
203  diag.attachNote(op.getLoc()) << "second non-pattern op";
204  return diag;
205  }
206  topLevelOp = &op;
207  continue;
208  }
209 
211  emitOpError()
212  << "expects only pattern and top-level transform ops in its body";
213  diag.attachNote(op.getLoc()) << "offending op";
214  return diag;
215  }
216 
// NOTE(review): nesting is presumably rejected because apply() attaches a
// single PatternApplicatorExtension to the state, and a nested op would try
// to attach a second one — confirm.
217  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
218  InFlightDiagnostic diag = emitOpError() << "cannot be nested";
219  diag.attachNote(parent.getLoc()) << "parent operation";
220  return diag;
221  }
222 
// apply() needs a transform op to run, so a body of only patterns is invalid.
223  if (!topLevelOp) {
224  InFlightDiagnostic diag = emitOpError()
225  << "expects at least one non-pattern op";
226  return diag;
227  }
228 
229  return success();
230 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:323
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition: Builders.h:244
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition: SymbolTable.h:24
PDL constraint callbacks that can be used by the PDL extension of the Transform dialect.
void mergeInPDLMatchHooks(llvm::StringMap< PDLConstraintFunction > &&constraintFns)
Takes ownership of the named PDL constraint functions from the given map and makes them available for use by the operations in the dialect.
const llvm::StringMap<::mlir::PDLConstraintFunction > & getPDLConstraintHooks() const
Returns the named PDL constraint functions available in the dialect as a map from their name to the function.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
The state maintained across applications of various ops implementing the TransformOpInterface.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operation.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423