MLIR  19.0.0git
PDLExtensionOps.cpp
Go to the documentation of this file.
1 //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
15 #include "llvm/ADT/ScopeExit.h"
16 
17 using namespace mlir;
18 
20 
21 #define GET_OP_CLASSES
22 #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc"
23 
24 //===----------------------------------------------------------------------===//
25 // PatternApplicatorExtension
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 /// A TransformState extension that keeps track of compiled PDL pattern sets.
30 /// This is intended to be used along the WithPDLPatterns op. The extension
31 /// can be constructed given an operation that has a SymbolTable trait and
32 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
33 /// by one when requested; this behavior is subject to change.
34 class PatternApplicatorExtension : public transform::TransformState::Extension {
35 public:
36  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
37 
38  /// Creates the extension for patterns contained in `patternContainer`.
39  explicit PatternApplicatorExtension(transform::TransformState &state,
40  Operation *patternContainer)
41  : Extension(state), patterns(patternContainer) {}
42 
43  /// Appends to `results` the operations contained in `root` that matched the
44  /// PDL pattern with the given name. Note that `root` may or may not be the
45  /// operation that contains PDL patterns. Returns failure, without emitting
46  /// a diagnostic, if the pattern cannot be found; reporting the error is left
  /// to the caller. Note that when no operations are matched, this still
47  /// succeeds as long as the pattern exists.
48  LogicalResult findAllMatches(StringRef patternName, Operation *root,
50 
51 private:
52  /// Map from the pattern name to a singleton set of rewrite patterns that only
53  /// contains the pattern with this name. Populated when the pattern is first
54  /// requested.
55  // TODO: reconsider the efficiency of this storage when more usage data is
56  // available. Storing individual patterns in a set and triggering compilation
57  // for each of them has overhead. So does compiling a large set of patterns
58  // only to apply a handful of them.
59  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
60 
61  /// A symbol table operation containing the relevant PDL patterns.
62  SymbolTable patterns;
63 };
64 
65 LogicalResult PatternApplicatorExtension::findAllMatches(
66  StringRef patternName, Operation *root,
  // Look up the compiled pattern set in the cache. On a miss, compile the
  // named pattern once and store it in `compiledPatterns` for later calls.
68  auto it = compiledPatterns.find(patternName);
69  if (it == compiledPatterns.end()) {
70  auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
  // Unknown pattern: no diagnostic is emitted here; the caller reports it.
71  if (!patternOp)
72  return failure();
73 
74  // Copy the pattern operation into a new module that is compiled and
75  // consumed by the PDL interpreter.
76  OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
77  auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody());
78  builder.clone(*patternOp);
79  PDLPatternModule patternModule(std::move(pdlModuleOp));
80 
81  // Merge in the hooks owned by the dialect. Make a copy as they may
82  // also be used by the following operations.
83  auto *dialect =
84  root->getContext()->getLoadedDialect<transform::TransformDialect>();
85  for (const auto &[name, constraintFn] :
86  dialect->getExtraData<transform::PDLMatchHooks>()
88  patternModule.registerConstraintFunction(name, constraintFn);
89  }
90 
91  // Register a noop rewriter because PDL requires patterns to end with some
92  // rewrite call.
93  patternModule.registerRewriteFunction(
94  "transform.dialect", [](PatternRewriter &, Operation *) {});
95 
96  it = compiledPatterns
97  .try_emplace(patternOp.getName(), std::move(patternModule))
98  .first;
99  }
100 
101  PatternApplicator applicator(it->second);
102  // We want to discourage direct use of PatternRewriter in APIs, but in this
103  // very specific case an IRRewriter is not enough: matchAndRewrite below
  // requires a PatternRewriter.
104  struct TrivialPatternRewriter : public PatternRewriter {
105  public:
106  explicit TrivialPatternRewriter(MLIRContext *context)
107  : PatternRewriter(context) {}
108  };
109  TrivialPatternRewriter rewriter(root->getContext());
110  applicator.applyDefaultCostModel();
  // Visit `root` and every operation nested in it. Each operation on which
  // the compiled pattern succeeds is appended to `results`. The only rewrite
  // function registered above is the no-op "transform.dialect" one, so a
  // successful match is expected to leave the payload IR unchanged. This
  // assumes the pattern's rewrite section only calls that function; confirm
  // against the patterns in use.
111  root->walk([&](Operation *op) {
112  if (succeeded(applicator.matchAndRewrite(op, rewriter)))
113  results.push_back(op);
114  });
115 
116  return success();
117 }
118 } // namespace
119 
120 //===----------------------------------------------------------------------===//
121 // PDLMatchHooks
122 //===----------------------------------------------------------------------===//
123 
125  llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
126  // Steal the constraint functions from the given map.
127  for (auto &it : constraintFns)
128  pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
129 }
130 
131 const llvm::StringMap<PDLConstraintFunction> &
133  return pdlMatchHooks.getConstraintFunctions();
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // PDLMatchOp
138 //===----------------------------------------------------------------------===//
139 
141 transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter,
143  transform::TransformState &state) {
144  auto *extension = state.getExtension<PatternApplicatorExtension>();
145  assert(extension &&
146  "expected PatternApplicatorExtension to be attached by the parent op");
147  SmallVector<Operation *> targets;
148  for (Operation *root : state.getPayloadOps(getRoot())) {
149  if (failed(extension->findAllMatches(
150  getPatternName().getLeafReference().getValue(), root, targets))) {
152  << "could not find pattern '" << getPatternName() << "'";
153  }
154  }
155  results.set(llvm::cast<OpResult>(getResult()), targets);
157 }
158 
159 void transform::PDLMatchOp::getEffects(
161  onlyReadsHandle(getRoot(), effects);
162  producesHandle(getMatched(), effects);
163  onlyReadsPayload(effects);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // WithPDLPatternsOp
168 //===----------------------------------------------------------------------===//
169 
171 transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter,
173  transform::TransformState &state) {
174  TransformOpInterface transformOp = nullptr;
175  for (Operation &nested : getBody().front()) {
176  if (!isa<pdl::PatternOp>(nested)) {
177  transformOp = cast<TransformOpInterface>(nested);
178  break;
179  }
180  }
181 
182  state.addExtension<PatternApplicatorExtension>(getOperation());
183  auto guard = llvm::make_scope_exit(
184  [&]() { state.removeExtension<PatternApplicatorExtension>(); });
185 
186  auto scope = state.make_region_scope(getBody());
187  if (failed(mapBlockArguments(state)))
189  return state.applyTransform(transformOp);
190 }
191 
192 void transform::WithPDLPatternsOp::getEffects(
195 }
196 
198  Block *body = getBodyBlock();
199  Operation *topLevelOp = nullptr;
200  for (Operation &op : body->getOperations()) {
201  if (isa<pdl::PatternOp>(op))
202  continue;
203 
205  if (topLevelOp) {
207  emitOpError() << "expects only one non-pattern op in its body";
208  diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
209  diag.attachNote(op.getLoc()) << "second non-pattern op";
210  return diag;
211  }
212  topLevelOp = &op;
213  continue;
214  }
215 
217  emitOpError()
218  << "expects only pattern and top-level transform ops in its body";
219  diag.attachNote(op.getLoc()) << "offending op";
220  return diag;
221  }
222 
223  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
224  InFlightDiagnostic diag = emitOpError() << "cannot be nested";
225  diag.attachNote(parent.getLoc()) << "parent operation";
226  return diag;
227  }
228 
229  if (!topLevelOp) {
230  InFlightDiagnostic diag = emitOpError()
231  << "expects at least one non-pattern op";
232  return diag;
233  }
234 
235  return success();
236 }
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:263
Block represents an ordered list of Operations.
Definition: Block.h:30
OpListType & getOperations()
Definition: Block.h:134
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition: Builders.h:248
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:745
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:793
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a pattern can be applied or not, and hooks for if the pattern match was a success or failure.
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition: SymbolTable.h:24
PDL constraint callbacks that can be used by the PDL extension of the Transform dialect.
void mergeInPDLMatchHooks(llvm::StringMap< PDLConstraintFunction > &&constraintFns)
Takes ownership of the named PDL constraint functions from the given map and makes them available for use by the operations in the dialect.
const llvm::StringMap<::mlir::PDLConstraintFunction > & getPDLConstraintHooks() const
Returns the named PDL constraint functions available in the dialect as a map from their name to the function.
This trait is supposed to be attached to Transform dialect operations that can be standalone top-level transforms.
Local mapping between values defined by a specific op implementing the TransformOpInterface and the payload IR ops they correspond to.
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
This is a special rewriter to be used in transform op implementations, providing additional helper functions to update the transform state, etc.
Base class for TransformState extensions that allow TransformState to contain user-specified information in the state and share it between hooks.
The state maintained across applications of various ops implementing the TransformOpInterface.
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operation.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26