MLIR  18.0.0git
PDLToPDLInterp.cpp
Go to the documentation of this file.
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"

#include "PredicateTree.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class generators operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
40  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
42 
43  /// Generate code for matching and rewriting based on the pattern operations
44  /// within the module.
45  void lower(ModuleOp module);
46 
47 private:
48  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
50 
51  /// Generate interpreter operations for the tree rooted at the given matcher
52  /// node, in the specified region.
53  Block *generateMatcher(MatcherNode &node, Region &region);
54 
55  /// Get or create an access to the provided positional value in the current
56  /// block. This operation may mutate the provided block pointer if nested
57  /// regions (i.e., pdl_interp.iterate) are required.
58  Value getValueAt(Block *&currentBlock, Position *pos);
59 
60  /// Create the interpreter predicate operations. This operation may mutate the
61  /// provided current block pointer if nested regions (iterates) are required.
62  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
63 
64  /// Create the interpreter switch / predicate operations, with several case
65  /// destinations. This operation never mutates the provided current block
66  /// pointer, because the switch operation does not need Values beyond `val`.
67  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
68 
69  /// Create the interpreter operations to record a successful pattern match
70  /// using the contained root operation. This operation may mutate the current
71  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
72  void generate(SuccessNode *successNode, Block *&currentBlock);
73 
74  /// Generate a rewriter function for the given pattern operation, and returns
75  /// a reference to that function.
76  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
77  SmallVectorImpl<Position *> &usedMatchValues);
78 
79  /// Generate the rewriter code for the given operation.
80  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
81  DenseMap<Value, Value> &rewriteValues,
82  function_ref<Value(Value)> mapRewriteValue);
83  void generateRewriter(pdl::AttributeOp attrOp,
84  DenseMap<Value, Value> &rewriteValues,
85  function_ref<Value(Value)> mapRewriteValue);
86  void generateRewriter(pdl::EraseOp eraseOp,
87  DenseMap<Value, Value> &rewriteValues,
88  function_ref<Value(Value)> mapRewriteValue);
89  void generateRewriter(pdl::OperationOp operationOp,
90  DenseMap<Value, Value> &rewriteValues,
91  function_ref<Value(Value)> mapRewriteValue);
92  void generateRewriter(pdl::RangeOp rangeOp,
93  DenseMap<Value, Value> &rewriteValues,
94  function_ref<Value(Value)> mapRewriteValue);
95  void generateRewriter(pdl::ReplaceOp replaceOp,
96  DenseMap<Value, Value> &rewriteValues,
97  function_ref<Value(Value)> mapRewriteValue);
98  void generateRewriter(pdl::ResultOp resultOp,
99  DenseMap<Value, Value> &rewriteValues,
100  function_ref<Value(Value)> mapRewriteValue);
101  void generateRewriter(pdl::ResultsOp resultOp,
102  DenseMap<Value, Value> &rewriteValues,
103  function_ref<Value(Value)> mapRewriteValue);
104  void generateRewriter(pdl::TypeOp typeOp,
105  DenseMap<Value, Value> &rewriteValues,
106  function_ref<Value(Value)> mapRewriteValue);
107  void generateRewriter(pdl::TypesOp typeOp,
108  DenseMap<Value, Value> &rewriteValues,
109  function_ref<Value(Value)> mapRewriteValue);
110 
111  /// Generate the values used for resolving the result types of an operation
112  /// created within a dag rewriter region. If the result types of the operation
113  /// should be inferred, `hasInferredResultTypes` is set to true.
114  void generateOperationResultTypeRewriter(
115  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
116  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
117  bool &hasInferredResultTypes);
118 
119  /// A builder to use when generating interpreter operations.
120  OpBuilder builder;
121 
122  /// The matcher function used for all match related logic within PDL patterns.
123  pdl_interp::FuncOp matcherFunc;
124 
125  /// The rewriter module containing the all rewrite related logic within PDL
126  /// patterns.
127  ModuleOp rewriterModule;
128 
129  /// The symbol table of the rewriter module used for insertion.
130  SymbolTable rewriterSymbolTable;
131 
132  /// A scoped map connecting a position with the corresponding interpreter
133  /// value.
134  ValueMap values;
135 
136  /// A stack of blocks used as the failure destination for matcher nodes that
137  /// don't have an explicit failure path.
138  SmallVector<Block *, 8> failureBlockStack;
139 
140  /// A mapping between values defined in a pattern match, and the corresponding
141  /// positional value.
142  DenseMap<Value, Position *> valueToPosition;
143 
144  /// The set of operation values whose location will be used for newly
145  /// generated operations.
146  SetVector<Value> locOps;
147 
148  /// A mapping between pattern operations and the corresponding configuration
149  /// set.
151 };
152 } // namespace
153 
154 PatternLowering::PatternLowering(
155  pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
157  : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
158  rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
159  configMap(configMap) {}
160 
161 void PatternLowering::lower(ModuleOp module) {
162  PredicateUniquer predicateUniquer;
163  PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
164 
165  // Define top-level scope for the arguments to the matcher function.
166  ValueMapScope topLevelValueScope(values);
167 
168  // Insert the root operation, i.e. argument to the matcher, at the root
169  // position.
170  Block *matcherEntryBlock = &matcherFunc.front();
171  values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
172 
173  // Generate a root matcher node from the provided PDL module.
174  std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
175  module, predicateBuilder, valueToPosition);
176  Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
177  assert(failureBlockStack.empty() && "failed to empty the stack");
178 
179  // After generation, merged the first matched block into the entry.
180  matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
181  firstMatcherBlock->getOperations());
182  firstMatcherBlock->erase();
183 }
184 
185 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
186  // Push a new scope for the values used by this matcher.
187  Block *block = &region.emplaceBlock();
188  ValueMapScope scope(values);
189 
190  // If this is the return node, simply insert the corresponding interpreter
191  // finalize.
192  if (isa<ExitNode>(node)) {
193  builder.setInsertionPointToEnd(block);
194  builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
195  return block;
196  }
197 
198  // Get the next block in the match sequence.
199  // This is intentionally executed first, before we get the value for the
200  // position associated with the node, so that we preserve an "there exist"
201  // semantics: if getting a value requires an upward traversal (going from a
202  // value to its consumers), we want to perform the check on all the consumers
203  // before we pass control to the failure node.
204  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
205  Block *failureBlock;
206  if (failureNode) {
207  failureBlock = generateMatcher(*failureNode, region);
208  failureBlockStack.push_back(failureBlock);
209  } else {
210  assert(!failureBlockStack.empty() && "expected valid failure block");
211  failureBlock = failureBlockStack.back();
212  }
213 
214  // If this node contains a position, get the corresponding value for this
215  // block.
216  Block *currentBlock = block;
217  Position *position = node.getPosition();
218  Value val = position ? getValueAt(currentBlock, position) : Value();
219 
220  // If this value corresponds to an operation, record that we are going to use
221  // its location as part of a fused location.
222  bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
223  if (isOperationValue)
224  locOps.insert(val);
225 
226  // Dispatch to the correct method based on derived node type.
228  .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
229  this->generate(derivedNode, currentBlock, val);
230  })
231  .Case([&](SuccessNode *successNode) {
232  generate(successNode, currentBlock);
233  });
234 
235  // Pop all the failure blocks that were inserted due to nesting of
236  // pdl_interp.iterate.
237  while (failureBlockStack.back() != failureBlock) {
238  failureBlockStack.pop_back();
239  assert(!failureBlockStack.empty() && "unable to locate failure block");
240  }
241 
242  // Pop the new failure block.
243  if (failureNode)
244  failureBlockStack.pop_back();
245 
246  if (isOperationValue)
247  locOps.remove(val);
248 
249  return block;
250 }
251 
252 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
253  if (Value val = values.lookup(pos))
254  return val;
255 
256  // Get the value for the parent position.
257  Value parentVal;
258  if (Position *parent = pos->getParent())
259  parentVal = getValueAt(currentBlock, parent);
260 
261  // TODO: Use a location from the position.
262  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
263  builder.setInsertionPointToEnd(currentBlock);
264  Value value;
265  switch (pos->getKind()) {
267  auto *operationPos = cast<OperationPosition>(pos);
268  if (operationPos->isOperandDefiningOp())
269  // Standard (downward) traversal which directly follows the defining op.
270  value = builder.create<pdl_interp::GetDefiningOpOp>(
271  loc, builder.getType<pdl::OperationType>(), parentVal);
272  else
273  // A passthrough operation position.
274  value = parentVal;
275  break;
276  }
277  case Predicates::UsersPos: {
278  auto *usersPos = cast<UsersPosition>(pos);
279 
280  // The first operation retrieves the representative value of a range.
281  // This applies only when the parent is a range of values and we were
282  // requested to use a representative value (e.g., upward traversal).
283  if (isa<pdl::RangeType>(parentVal.getType()) &&
284  usersPos->useRepresentative())
285  value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
286  else
287  value = parentVal;
288 
289  // The second operation retrieves the users.
290  value = builder.create<pdl_interp::GetUsersOp>(loc, value);
291  break;
292  }
293  case Predicates::ForEachPos: {
294  assert(!failureBlockStack.empty() && "expected valid failure block");
295  auto foreach = builder.create<pdl_interp::ForEachOp>(
296  loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
297  value = foreach.getLoopVariable();
298 
299  // Create the continuation block.
300  Block *continueBlock = builder.createBlock(&foreach.getRegion());
301  builder.create<pdl_interp::ContinueOp>(loc);
302  failureBlockStack.push_back(continueBlock);
303 
304  currentBlock = &foreach.getRegion().front();
305  break;
306  }
307  case Predicates::OperandPos: {
308  auto *operandPos = cast<OperandPosition>(pos);
309  value = builder.create<pdl_interp::GetOperandOp>(
310  loc, builder.getType<pdl::ValueType>(), parentVal,
311  operandPos->getOperandNumber());
312  break;
313  }
315  auto *operandPos = cast<OperandGroupPosition>(pos);
316  Type valueTy = builder.getType<pdl::ValueType>();
317  value = builder.create<pdl_interp::GetOperandsOp>(
318  loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
319  parentVal, operandPos->getOperandGroupNumber());
320  break;
321  }
323  auto *attrPos = cast<AttributePosition>(pos);
324  value = builder.create<pdl_interp::GetAttributeOp>(
325  loc, builder.getType<pdl::AttributeType>(), parentVal,
326  attrPos->getName().strref());
327  break;
328  }
329  case Predicates::TypePos: {
330  if (isa<pdl::AttributeType>(parentVal.getType()))
331  value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
332  else
333  value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
334  break;
335  }
336  case Predicates::ResultPos: {
337  auto *resPos = cast<ResultPosition>(pos);
338  value = builder.create<pdl_interp::GetResultOp>(
339  loc, builder.getType<pdl::ValueType>(), parentVal,
340  resPos->getResultNumber());
341  break;
342  }
344  auto *resPos = cast<ResultGroupPosition>(pos);
345  Type valueTy = builder.getType<pdl::ValueType>();
346  value = builder.create<pdl_interp::GetResultsOp>(
347  loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
348  parentVal, resPos->getResultGroupNumber());
349  break;
350  }
352  auto *attrPos = cast<AttributeLiteralPosition>(pos);
353  value =
354  builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
355  break;
356  }
358  auto *typePos = cast<TypeLiteralPosition>(pos);
359  Attribute rawTypeAttr = typePos->getValue();
360  if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
361  value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
362  else
363  value = builder.create<pdl_interp::CreateTypesOp>(
364  loc, cast<ArrayAttr>(rawTypeAttr));
365  break;
366  }
367  default:
368  llvm_unreachable("Generating unknown Position getter");
369  break;
370  }
371 
372  values.insert(pos, value);
373  return value;
374 }
375 
376 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
377  Value val) {
378  Location loc = val.getLoc();
379  Qualifier *question = boolNode->getQuestion();
380  Qualifier *answer = boolNode->getAnswer();
381  Region *region = currentBlock->getParent();
382 
383  // Execute the getValue queries first, so that we create success
384  // matcher in the correct (possibly nested) region.
385  SmallVector<Value> args;
386  if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
387  args = {getValueAt(currentBlock, equalToQuestion->getValue())};
388  } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
389  for (Position *position : cstQuestion->getArgs())
390  args.push_back(getValueAt(currentBlock, position));
391  }
392 
393  // Generate the matcher in the current (potentially nested) region
394  // and get the failure successor.
395  Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
396  Block *failure = failureBlockStack.back();
397 
398  // Finally, create the predicate.
399  builder.setInsertionPointToEnd(currentBlock);
400  Predicates::Kind kind = question->getKind();
401  switch (kind) {
403  builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
404  break;
406  auto *opNameAnswer = cast<OperationNameAnswer>(answer);
407  builder.create<pdl_interp::CheckOperationNameOp>(
408  loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
409  break;
410  }
412  auto *ans = cast<TypeAnswer>(answer);
413  if (isa<pdl::RangeType>(val.getType()))
414  builder.create<pdl_interp::CheckTypesOp>(
415  loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
416  else
417  builder.create<pdl_interp::CheckTypeOp>(
418  loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
419  break;
420  }
422  auto *ans = cast<AttributeAnswer>(answer);
423  builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
424  success, failure);
425  break;
426  }
429  builder.create<pdl_interp::CheckOperandCountOp>(
430  loc, val, cast<UnsignedAnswer>(answer)->getValue(),
431  /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
432  success, failure);
433  break;
436  builder.create<pdl_interp::CheckResultCountOp>(
437  loc, val, cast<UnsignedAnswer>(answer)->getValue(),
438  /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
439  success, failure);
440  break;
442  bool trueAnswer = isa<TrueAnswer>(answer);
443  builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
444  trueAnswer ? success : failure,
445  trueAnswer ? failure : success);
446  break;
447  }
449  auto *cstQuestion = cast<ConstraintQuestion>(question);
450  builder.create<pdl_interp::ApplyConstraintOp>(
451  loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
452  failure);
453  break;
454  }
455  default:
456  llvm_unreachable("Generating unknown Predicate operation");
457  }
458 }
459 
460 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
461 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
462  llvm::MapVector<Qualifier *, Block *> &dests) {
463  std::vector<ValT> values;
464  std::vector<Block *> blocks;
465  values.reserve(dests.size());
466  blocks.reserve(dests.size());
467  for (const auto &it : dests) {
468  blocks.push_back(it.second);
469  values.push_back(cast<PredT>(it.first)->getValue());
470  }
471  builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
472 }
473 
474 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
475  Value val) {
476  Qualifier *question = switchNode->getQuestion();
477  Region *region = currentBlock->getParent();
478  Block *defaultDest = failureBlockStack.back();
479 
480  // If the switch question is not an exact answer, i.e. for the `at_least`
481  // cases, we generate a special block sequence.
482  Predicates::Kind kind = question->getKind();
485  // Order the children such that the cases are in reverse numerical order.
486  SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
487  llvm::seq<unsigned>(0, switchNode->getChildren().size()));
488  llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
489  return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
490  cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
491  });
492 
493  // Build the destination for each child using the next highest child as a
494  // a failure destination. This essentially creates the following control
495  // flow:
496  //
497  // if (operand_count < 1)
498  // goto failure
499  // if (child1.match())
500  // ...
501  //
502  // if (operand_count < 2)
503  // goto failure
504  // if (child2.match())
505  // ...
506  //
507  // failure:
508  // ...
509  //
510  failureBlockStack.push_back(defaultDest);
511  Location loc = val.getLoc();
512  for (unsigned idx : sortedChildren) {
513  auto &child = switchNode->getChild(idx);
514  Block *childBlock = generateMatcher(*child.second, *region);
515  Block *predicateBlock = builder.createBlock(childBlock);
516  builder.setInsertionPointToEnd(predicateBlock);
517  unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
518  switch (kind) {
520  builder.create<pdl_interp::CheckOperandCountOp>(
521  loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
522  break;
524  builder.create<pdl_interp::CheckResultCountOp>(
525  loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
526  break;
527  default:
528  llvm_unreachable("Generating invalid AtLeast operation");
529  }
530  failureBlockStack.back() = predicateBlock;
531  }
532  Block *firstPredicateBlock = failureBlockStack.pop_back_val();
533  currentBlock->getOperations().splice(currentBlock->end(),
534  firstPredicateBlock->getOperations());
535  firstPredicateBlock->erase();
536  return;
537  }
538 
539  // Otherwise, generate each of the children and generate an interpreter
540  // switch.
541  llvm::MapVector<Qualifier *, Block *> children;
542  for (auto &it : switchNode->getChildren())
543  children.insert({it.first, generateMatcher(*it.second, *region)});
544  builder.setInsertionPointToEnd(currentBlock);
545 
546  switch (question->getKind()) {
548  return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
549  int32_t>(val, defaultDest, builder, children);
551  return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
552  int32_t>(val, defaultDest, builder, children);
554  return createSwitchOp<pdl_interp::SwitchOperationNameOp,
555  OperationNameAnswer>(val, defaultDest, builder,
556  children);
558  if (isa<pdl::RangeType>(val.getType())) {
559  return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
560  val, defaultDest, builder, children);
561  }
562  return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
563  val, defaultDest, builder, children);
565  return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
566  val, defaultDest, builder, children);
567  default:
568  llvm_unreachable("Generating unknown switch predicate.");
569  }
570 }
571 
572 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
573  pdl::PatternOp pattern = successNode->getPattern();
574  Value root = successNode->getRoot();
575 
576  // Generate a rewriter for the pattern this success node represents, and track
577  // any values used from the match region.
578  SmallVector<Position *, 8> usedMatchValues;
579  SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
580 
581  // Process any values used in the rewrite that are defined in the match.
582  std::vector<Value> mappedMatchValues;
583  mappedMatchValues.reserve(usedMatchValues.size());
584  for (Position *position : usedMatchValues)
585  mappedMatchValues.push_back(getValueAt(currentBlock, position));
586 
587  // Collect the set of operations generated by the rewriter.
588  SmallVector<StringRef, 4> generatedOps;
589  for (auto op :
590  pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
591  generatedOps.push_back(*op.getOpName());
592  ArrayAttr generatedOpsAttr;
593  if (!generatedOps.empty())
594  generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
595 
596  // Grab the root kind if present.
597  StringAttr rootKindAttr;
598  if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
599  if (std::optional<StringRef> rootKind = rootOp.getOpName())
600  rootKindAttr = builder.getStringAttr(*rootKind);
601 
602  builder.setInsertionPointToEnd(currentBlock);
603  auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
604  pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
605  rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
606  failureBlockStack.back());
607 
608  // Set the config of the lowered match to the parent pattern.
609  if (configMap)
610  configMap->try_emplace(matchOp, configMap->lookup(pattern));
611 }
612 
613 SymbolRefAttr PatternLowering::generateRewriter(
614  pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
615  builder.setInsertionPointToEnd(rewriterModule.getBody());
616  auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
617  pattern.getLoc(), "pdl_generated_rewriter",
618  builder.getFunctionType(std::nullopt, std::nullopt));
619  rewriterSymbolTable.insert(rewriterFunc);
620 
621  // Generate the rewriter function body.
622  builder.setInsertionPointToEnd(&rewriterFunc.front());
623 
624  // Map an input operand of the pattern to a generated interpreter value.
625  DenseMap<Value, Value> rewriteValues;
626  auto mapRewriteValue = [&](Value oldValue) {
627  Value &newValue = rewriteValues[oldValue];
628  if (newValue)
629  return newValue;
630 
631  // Prefer materializing constants directly when possible.
632  Operation *oldOp = oldValue.getDefiningOp();
633  if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
634  if (Attribute value = attrOp.getValueAttr()) {
635  return newValue = builder.create<pdl_interp::CreateAttributeOp>(
636  attrOp.getLoc(), value);
637  }
638  } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
639  if (TypeAttr type = typeOp.getConstantTypeAttr()) {
640  return newValue = builder.create<pdl_interp::CreateTypeOp>(
641  typeOp.getLoc(), type);
642  }
643  } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
644  if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
645  return newValue = builder.create<pdl_interp::CreateTypesOp>(
646  typeOp.getLoc(), typeOp.getType(), type);
647  }
648  }
649 
650  // Otherwise, add this as an input to the rewriter.
651  Position *inputPos = valueToPosition.lookup(oldValue);
652  assert(inputPos && "expected value to be a pattern input");
653  usedMatchValues.push_back(inputPos);
654  return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
655  oldValue.getLoc());
656  };
657 
658  // If this is a custom rewriter, simply dispatch to the registered rewrite
659  // method.
660  pdl::RewriteOp rewriter = pattern.getRewriter();
661  if (StringAttr rewriteName = rewriter.getNameAttr()) {
662  SmallVector<Value> args;
663  if (rewriter.getRoot())
664  args.push_back(mapRewriteValue(rewriter.getRoot()));
665  auto mappedArgs =
666  llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
667  args.append(mappedArgs.begin(), mappedArgs.end());
668  builder.create<pdl_interp::ApplyRewriteOp>(
669  rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
670  } else {
671  // Otherwise this is a dag rewriter defined using PDL operations.
672  for (Operation &rewriteOp : *rewriter.getBody()) {
674  .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
675  pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
676  pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
677  this->generateRewriter(op, rewriteValues, mapRewriteValue);
678  });
679  }
680  }
681 
682  // Update the signature of the rewrite function.
683  rewriterFunc.setType(builder.getFunctionType(
684  llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
685  /*results=*/std::nullopt));
686 
687  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
688  return SymbolRefAttr::get(
689  builder.getContext(),
690  pdl_interp::PDLInterpDialect::getRewriterModuleName(),
691  SymbolRefAttr::get(rewriterFunc));
692 }
693 
694 void PatternLowering::generateRewriter(
695  pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
696  function_ref<Value(Value)> mapRewriteValue) {
697  SmallVector<Value, 2> arguments;
698  for (Value argument : rewriteOp.getArgs())
699  arguments.push_back(mapRewriteValue(argument));
700  auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
701  rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
702  arguments);
703  for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
704  rewriteValues[std::get<0>(it)] = std::get<1>(it);
705 }
706 
707 void PatternLowering::generateRewriter(
708  pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
709  function_ref<Value(Value)> mapRewriteValue) {
710  Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
711  attrOp.getLoc(), attrOp.getValueAttr());
712  rewriteValues[attrOp] = newAttr;
713 }
714 
715 void PatternLowering::generateRewriter(
716  pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
717  function_ref<Value(Value)> mapRewriteValue) {
718  builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
719  mapRewriteValue(eraseOp.getOpValue()));
720 }
721 
722 void PatternLowering::generateRewriter(
723  pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
724  function_ref<Value(Value)> mapRewriteValue) {
725  SmallVector<Value, 4> operands;
726  for (Value operand : operationOp.getOperandValues())
727  operands.push_back(mapRewriteValue(operand));
728 
729  SmallVector<Value, 4> attributes;
730  for (Value attr : operationOp.getAttributeValues())
731  attributes.push_back(mapRewriteValue(attr));
732 
733  bool hasInferredResultTypes = false;
734  SmallVector<Value, 2> types;
735  generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
736  rewriteValues, hasInferredResultTypes);
737 
738  // Create the new operation.
739  Location loc = operationOp.getLoc();
740  Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
741  loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
742  attributes, operationOp.getAttributeValueNames());
743  rewriteValues[operationOp.getOp()] = createdOp;
744 
745  // Generate accesses for any results that have their types constrained.
746  // Handle the case where there is a single range representing all of the
747  // result types.
748  OperandRange resultTys = operationOp.getTypeValues();
749  if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
750  Value &type = rewriteValues[resultTys[0]];
751  if (!type) {
752  auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
753  type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
754  }
755  return;
756  }
757 
758  // Otherwise, populate the individual results.
759  bool seenVariableLength = false;
760  Type valueTy = builder.getType<pdl::ValueType>();
761  Type valueRangeTy = pdl::RangeType::get(valueTy);
762  for (const auto &it : llvm::enumerate(resultTys)) {
763  Value &type = rewriteValues[it.value()];
764  if (type)
765  continue;
766  bool isVariadic = isa<pdl::RangeType>(it.value().getType());
767  seenVariableLength |= isVariadic;
768 
769  // After a variable length result has been seen, we need to use result
770  // groups because the exact index of the result is not statically known.
771  Value resultVal;
772  if (seenVariableLength)
773  resultVal = builder.create<pdl_interp::GetResultsOp>(
774  loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
775  else
776  resultVal = builder.create<pdl_interp::GetResultOp>(
777  loc, valueTy, createdOp, it.index());
778  type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
779  }
780 }
781 
782 void PatternLowering::generateRewriter(
783  pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
784  function_ref<Value(Value)> mapRewriteValue) {
785  SmallVector<Value, 4> replOperands;
786  for (Value operand : rangeOp.getArguments())
787  replOperands.push_back(mapRewriteValue(operand));
788  rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
789  rangeOp.getLoc(), rangeOp.getType(), replOperands);
790 }
791 
792 void PatternLowering::generateRewriter(
793  pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
794  function_ref<Value(Value)> mapRewriteValue) {
795  SmallVector<Value, 4> replOperands;
796 
797  // If the replacement was another operation, get its results. `pdl` allows
798  // for using an operation for simplicitly, but the interpreter isn't as
799  // user facing.
800  if (Value replOp = replaceOp.getReplOperation()) {
801  // Don't use replace if we know the replaced operation has no results.
802  auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
803  if (!opOp || !opOp.getTypeValues().empty()) {
804  replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
805  replOp.getLoc(), mapRewriteValue(replOp)));
806  }
807  } else {
808  for (Value operand : replaceOp.getReplValues())
809  replOperands.push_back(mapRewriteValue(operand));
810  }
811 
812  // If there are no replacement values, just create an erase instead.
813  if (replOperands.empty()) {
814  builder.create<pdl_interp::EraseOp>(
815  replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
816  return;
817  }
818 
819  builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
820  mapRewriteValue(replaceOp.getOpValue()),
821  replOperands);
822 }
823 
824 void PatternLowering::generateRewriter(
825  pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
826  function_ref<Value(Value)> mapRewriteValue) {
827  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
828  resultOp.getLoc(), builder.getType<pdl::ValueType>(),
829  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
830 }
831 
832 void PatternLowering::generateRewriter(
833  pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
834  function_ref<Value(Value)> mapRewriteValue) {
835  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
836  resultOp.getLoc(), resultOp.getType(),
837  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
838 }
839 
840 void PatternLowering::generateRewriter(
841  pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
842  function_ref<Value(Value)> mapRewriteValue) {
843  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
844  // type.
845  if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
846  rewriteValues[typeOp] =
847  builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
848  }
849 }
850 
851 void PatternLowering::generateRewriter(
852  pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
853  function_ref<Value(Value)> mapRewriteValue) {
854  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
855  // type.
856  if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
857  rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
858  typeOp.getLoc(), typeOp.getType(), typeAttr);
859  }
860 }
861 
862 void PatternLowering::generateOperationResultTypeRewriter(
863  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
864  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
865  bool &hasInferredResultTypes) {
866  Block *rewriterBlock = op->getBlock();
867 
868  // Try to handle resolution for each of the result types individually. This is
869  // preferred over type inferrence because it will allow for us to use existing
870  // types directly, as opposed to trying to rebuild the type list.
871  OperandRange resultTypeValues = op.getTypeValues();
872  auto tryResolveResultTypes = [&] {
873  types.reserve(resultTypeValues.size());
874  for (const auto &it : llvm::enumerate(resultTypeValues)) {
875  Value resultType = it.value();
876 
877  // Check for an already translated value.
878  if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
879  types.push_back(existingRewriteValue);
880  continue;
881  }
882 
883  // Check for an input from the matcher.
884  if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
885  types.push_back(mapRewriteValue(resultType));
886  continue;
887  }
888 
889  // Otherwise, we couldn't infer the result types. Bail out here to see if
890  // we can infer the types for this operation from another way.
891  types.clear();
892  return failure();
893  }
894  return success();
895  };
896  if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
897  return;
898 
899  // Otherwise, check if the operation has type inference support itself.
900  if (op.hasTypeInference()) {
901  hasInferredResultTypes = true;
902  return;
903  }
904 
905  // Look for an operation that was replaced by `op`. The result types will be
906  // inferred from the results that were replaced.
907  for (OpOperand &use : op.getOp().getUses()) {
908  // Check that the use corresponds to a ReplaceOp and that it is the
909  // replacement value, not the operation being replaced.
910  pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
911  if (!replOpUser || use.getOperandNumber() == 0)
912  continue;
913  // Make sure the replaced operation was defined before this one. PDL
914  // rewrites only have single block regions, so if the op isn't in the
915  // rewriter block (i.e. the current block of the operation) we already know
916  // it dominates (i.e. it's in the matcher).
917  Value replOpVal = replOpUser.getOpValue();
918  Operation *replacedOp = replOpVal.getDefiningOp();
919  if (replacedOp->getBlock() == rewriterBlock &&
920  !replacedOp->isBeforeInBlock(op))
921  continue;
922 
923  Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
924  replacedOp->getLoc(), mapRewriteValue(replOpVal));
925  types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
926  replacedOp->getLoc(), replacedOpResults));
927  return;
928  }
929 
930  // If the types could not be inferred from any context and there weren't any
931  // explicit result types, assume the user actually meant for the operation to
932  // have no results.
933  if (resultTypeValues.empty())
934  return;
935 
936  // The verifier asserts that the result types of each pdl.getOperation can be
937  // inferred. If we reach here, there is a bug either in the logic above or
938  // in the verifier for pdl.getOperation.
939  op->emitOpError() << "unable to infer result type for operation";
940  llvm_unreachable("unable to infer result type for operation");
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // Conversion Pass
945 //===----------------------------------------------------------------------===//
946 
947 namespace {
948 struct PDLToPDLInterpPass
949  : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
950  PDLToPDLInterpPass() = default;
951  PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
952  PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
953  : configMap(&configMap) {}
954  void runOnOperation() final;
955 
956  /// A map containing the configuration for each pattern.
957  DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
958 };
959 } // namespace
960 
961 /// Convert the given module containing PDL pattern operations into a PDL
962 /// Interpreter operations.
963 void PDLToPDLInterpPass::runOnOperation() {
964  ModuleOp module = getOperation();
965 
966  // Create the main matcher function This function contains all of the match
967  // related functionality from patterns in the module.
968  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
969  auto matcherFunc = builder.create<pdl_interp::FuncOp>(
970  module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
971  builder.getFunctionType(builder.getType<pdl::OperationType>(),
972  /*results=*/std::nullopt),
973  /*attrs=*/std::nullopt);
974 
975  // Create a nested module to hold the functions invoked for rewriting the IR
976  // after a successful match.
977  ModuleOp rewriterModule = builder.create<ModuleOp>(
978  module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
979 
980  // Generate the code for the patterns within the module.
981  PatternLowering generator(matcherFunc, rewriterModule, configMap);
982  generator.lower(module);
983 
984  // After generation, delete all of the pattern operations.
985  for (pdl::PatternOp pattern :
986  llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
987  // Drop the now dead config mappings.
988  if (configMap)
989  configMap->erase(pattern);
990 
991  pattern.erase();
992  }
993 }
994 
995 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
996  return std::make_unique<PDLToPDLInterpPass>();
997 }
998 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
1000  return std::make_unique<PDLToPDLInterpPass>(configMap);
1001 }
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
Operation & back()
Definition: Block.h:145
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:60
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
OpListType & getOperations()
Definition: Block.h:130
Operation & front()
Definition: Block.h:146
iterator end()
Definition: Block.h:137
void push_back(Operation *op)
Definition: Block.h:142
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:93
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:263
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:825
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:640
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:979
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & emplaceBlock()
Definition: Region.h:46
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Definition: PredicateTree.h:70
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:143
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition: Predicate.h:152
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:155
This class provides utilities for constructing predicates.
Definition: Predicate.h:566
This class provides a storage uniquer that is used to allocate predicate instances.
Definition: Predicate.h:523
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:387
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:392
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition: Predicate.h:46
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createPDLToPDLInterpPass()
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
A BoolNode denotes a question with a boolean-like result.
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
An Answer representing an OperationName value.
Definition: Predicate.h:412
A SuccessNode denotes that a given high level pattern has successfully been matched.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
A SwitchNode denotes a question with multiple potential results.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.