MLIR  22.0.0git
PDLToPDLInterp.cpp
Go to the documentation of this file.
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "PredicateTree.h"
14 #include "mlir/Pass/Pass.h"
15 #include "llvm/ADT/MapVector.h"
16 #include "llvm/ADT/ScopedHashTable.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::pdl_to_pdl_interp;
28 
29 //===----------------------------------------------------------------------===//
30 // PatternLowering
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// This class generators operations within the PDL Interpreter dialect from a
35 /// given module containing PDL pattern operations.
36 struct PatternLowering {
37 public:
38  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
40 
41  /// Generate code for matching and rewriting based on the pattern operations
42  /// within the module.
43  void lower(ModuleOp module);
44 
45 private:
46  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
47  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
48 
49  /// Generate interpreter operations for the tree rooted at the given matcher
50  /// node, in the specified region.
51  Block *generateMatcher(MatcherNode &node, Region &region,
52  Block *block = nullptr);
53 
54  /// Get or create an access to the provided positional value in the current
55  /// block. This operation may mutate the provided block pointer if nested
56  /// regions (i.e., pdl_interp.iterate) are required.
57  Value getValueAt(Block *&currentBlock, Position *pos);
58 
59  /// Create the interpreter predicate operations. This operation may mutate the
60  /// provided current block pointer if nested regions (iterates) are required.
61  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
62 
63  /// Create the interpreter switch / predicate operations, with several case
64  /// destinations. This operation never mutates the provided current block
65  /// pointer, because the switch operation does not need Values beyond `val`.
66  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
67 
68  /// Create the interpreter operations to record a successful pattern match
69  /// using the contained root operation. This operation may mutate the current
70  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
71  void generate(SuccessNode *successNode, Block *&currentBlock);
72 
73  /// Generate a rewriter function for the given pattern operation, and returns
74  /// a reference to that function.
75  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76  SmallVectorImpl<Position *> &usedMatchValues);
77 
78  /// Generate the rewriter code for the given operation.
79  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80  DenseMap<Value, Value> &rewriteValues,
81  function_ref<Value(Value)> mapRewriteValue);
82  void generateRewriter(pdl::AttributeOp attrOp,
83  DenseMap<Value, Value> &rewriteValues,
84  function_ref<Value(Value)> mapRewriteValue);
85  void generateRewriter(pdl::EraseOp eraseOp,
86  DenseMap<Value, Value> &rewriteValues,
87  function_ref<Value(Value)> mapRewriteValue);
88  void generateRewriter(pdl::OperationOp operationOp,
89  DenseMap<Value, Value> &rewriteValues,
90  function_ref<Value(Value)> mapRewriteValue);
91  void generateRewriter(pdl::RangeOp rangeOp,
92  DenseMap<Value, Value> &rewriteValues,
93  function_ref<Value(Value)> mapRewriteValue);
94  void generateRewriter(pdl::ReplaceOp replaceOp,
95  DenseMap<Value, Value> &rewriteValues,
96  function_ref<Value(Value)> mapRewriteValue);
97  void generateRewriter(pdl::ResultOp resultOp,
98  DenseMap<Value, Value> &rewriteValues,
99  function_ref<Value(Value)> mapRewriteValue);
100  void generateRewriter(pdl::ResultsOp resultOp,
101  DenseMap<Value, Value> &rewriteValues,
102  function_ref<Value(Value)> mapRewriteValue);
103  void generateRewriter(pdl::TypeOp typeOp,
104  DenseMap<Value, Value> &rewriteValues,
105  function_ref<Value(Value)> mapRewriteValue);
106  void generateRewriter(pdl::TypesOp typeOp,
107  DenseMap<Value, Value> &rewriteValues,
108  function_ref<Value(Value)> mapRewriteValue);
109 
110  /// Generate the values used for resolving the result types of an operation
111  /// created within a dag rewriter region. If the result types of the operation
112  /// should be inferred, `hasInferredResultTypes` is set to true.
113  void generateOperationResultTypeRewriter(
114  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
115  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
116  bool &hasInferredResultTypes);
117 
118  /// A builder to use when generating interpreter operations.
119  OpBuilder builder;
120 
121  /// The matcher function used for all match related logic within PDL patterns.
122  pdl_interp::FuncOp matcherFunc;
123 
124  /// The rewriter module containing the all rewrite related logic within PDL
125  /// patterns.
126  ModuleOp rewriterModule;
127 
128  /// The symbol table of the rewriter module used for insertion.
129  SymbolTable rewriterSymbolTable;
130 
131  /// A scoped map connecting a position with the corresponding interpreter
132  /// value.
133  ValueMap values;
134 
135  /// A stack of blocks used as the failure destination for matcher nodes that
136  /// don't have an explicit failure path.
137  SmallVector<Block *, 8> failureBlockStack;
138 
139  /// A mapping between values defined in a pattern match, and the corresponding
140  /// positional value.
141  DenseMap<Value, Position *> valueToPosition;
142 
143  /// The set of operation values whose location will be used for newly
144  /// generated operations.
145  SetVector<Value> locOps;
146 
147  /// A mapping between pattern operations and the corresponding configuration
148  /// set.
150 
151  /// A mapping from a constraint question to the ApplyConstraintOp
152  /// that implements it.
154 };
155 } // namespace
156 
157 PatternLowering::PatternLowering(
158  pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
160  : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
161  rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162  configMap(configMap) {}
163 
164 void PatternLowering::lower(ModuleOp module) {
165  PredicateUniquer predicateUniquer;
166  PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
167 
168  // Define top-level scope for the arguments to the matcher function.
169  ValueMapScope topLevelValueScope(values);
170 
171  // Insert the root operation, i.e. argument to the matcher, at the root
172  // position.
173  Block *matcherEntryBlock = &matcherFunc.front();
174  values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
175 
176  // Generate a root matcher node from the provided PDL module.
177  std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
178  module, predicateBuilder, valueToPosition);
179  Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
180  assert(failureBlockStack.empty() && "failed to empty the stack");
181 
182  // After generation, merged the first matched block into the entry.
183  matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
184  firstMatcherBlock->getOperations());
185  firstMatcherBlock->erase();
186 }
187 
188 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
189  Block *block) {
190  // Push a new scope for the values used by this matcher.
191  if (!block)
192  block = &region.emplaceBlock();
193  ValueMapScope scope(values);
194 
195  // If this is the return node, simply insert the corresponding interpreter
196  // finalize.
197  if (isa<ExitNode>(node)) {
198  builder.setInsertionPointToEnd(block);
199  pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
200  return block;
201  }
202 
203  // Get the next block in the match sequence.
204  // This is intentionally executed first, before we get the value for the
205  // position associated with the node, so that we preserve an "there exist"
206  // semantics: if getting a value requires an upward traversal (going from a
207  // value to its consumers), we want to perform the check on all the consumers
208  // before we pass control to the failure node.
209  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
210  Block *failureBlock;
211  if (failureNode) {
212  failureBlock = generateMatcher(*failureNode, region);
213  failureBlockStack.push_back(failureBlock);
214  } else {
215  assert(!failureBlockStack.empty() && "expected valid failure block");
216  failureBlock = failureBlockStack.back();
217  }
218 
219  // If this node contains a position, get the corresponding value for this
220  // block.
221  Block *currentBlock = block;
222  Position *position = node.getPosition();
223  Value val = position ? getValueAt(currentBlock, position) : Value();
224 
225  // If this value corresponds to an operation, record that we are going to use
226  // its location as part of a fused location.
227  bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
228  if (isOperationValue)
229  locOps.insert(val);
230 
231  // Dispatch to the correct method based on derived node type.
233  .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
234  this->generate(derivedNode, currentBlock, val);
235  })
236  .Case([&](SuccessNode *successNode) {
237  generate(successNode, currentBlock);
238  });
239 
240  // Pop all the failure blocks that were inserted due to nesting of
241  // pdl_interp.iterate.
242  while (failureBlockStack.back() != failureBlock) {
243  failureBlockStack.pop_back();
244  assert(!failureBlockStack.empty() && "unable to locate failure block");
245  }
246 
247  // Pop the new failure block.
248  if (failureNode)
249  failureBlockStack.pop_back();
250 
251  if (isOperationValue)
252  locOps.remove(val);
253 
254  return block;
255 }
256 
257 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
258  if (Value val = values.lookup(pos))
259  return val;
260 
261  // Get the value for the parent position.
262  Value parentVal;
263  if (Position *parent = pos->getParent())
264  parentVal = getValueAt(currentBlock, parent);
265 
266  // TODO: Use a location from the position.
267  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
268  builder.setInsertionPointToEnd(currentBlock);
269  Value value;
270  switch (pos->getKind()) {
272  auto *operationPos = cast<OperationPosition>(pos);
273  if (operationPos->isOperandDefiningOp())
274  // Standard (downward) traversal which directly follows the defining op.
275  value = pdl_interp::GetDefiningOpOp::create(
276  builder, loc, builder.getType<pdl::OperationType>(), parentVal);
277  else
278  // A passthrough operation position.
279  value = parentVal;
280  break;
281  }
282  case Predicates::UsersPos: {
283  auto *usersPos = cast<UsersPosition>(pos);
284 
285  // The first operation retrieves the representative value of a range.
286  // This applies only when the parent is a range of values and we were
287  // requested to use a representative value (e.g., upward traversal).
288  if (isa<pdl::RangeType>(parentVal.getType()) &&
289  usersPos->useRepresentative())
290  value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
291  else
292  value = parentVal;
293 
294  // The second operation retrieves the users.
295  value = pdl_interp::GetUsersOp::create(builder, loc, value);
296  break;
297  }
298  case Predicates::ForEachPos: {
299  assert(!failureBlockStack.empty() && "expected valid failure block");
300  auto foreach = pdl_interp::ForEachOp::create(
301  builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
302  value = foreach.getLoopVariable();
303 
304  // Create the continuation block.
305  Block *continueBlock = builder.createBlock(&foreach.getRegion());
306  pdl_interp::ContinueOp::create(builder, loc);
307  failureBlockStack.push_back(continueBlock);
308 
309  currentBlock = &foreach.getRegion().front();
310  break;
311  }
312  case Predicates::OperandPos: {
313  auto *operandPos = cast<OperandPosition>(pos);
314  value = pdl_interp::GetOperandOp::create(
315  builder, loc, builder.getType<pdl::ValueType>(), parentVal,
316  operandPos->getOperandNumber());
317  break;
318  }
320  auto *operandPos = cast<OperandGroupPosition>(pos);
321  Type valueTy = builder.getType<pdl::ValueType>();
322  value = pdl_interp::GetOperandsOp::create(
323  builder, loc,
324  operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
325  parentVal, operandPos->getOperandGroupNumber());
326  break;
327  }
329  auto *attrPos = cast<AttributePosition>(pos);
330  value = pdl_interp::GetAttributeOp::create(
331  builder, loc, builder.getType<pdl::AttributeType>(), parentVal,
332  attrPos->getName().strref());
333  break;
334  }
335  case Predicates::TypePos: {
336  if (isa<pdl::AttributeType>(parentVal.getType()))
337  value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
338  else
339  value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
340  break;
341  }
342  case Predicates::ResultPos: {
343  auto *resPos = cast<ResultPosition>(pos);
344  value = pdl_interp::GetResultOp::create(
345  builder, loc, builder.getType<pdl::ValueType>(), parentVal,
346  resPos->getResultNumber());
347  break;
348  }
350  auto *resPos = cast<ResultGroupPosition>(pos);
351  Type valueTy = builder.getType<pdl::ValueType>();
352  value = pdl_interp::GetResultsOp::create(
353  builder, loc,
354  resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355  parentVal, resPos->getResultGroupNumber());
356  break;
357  }
359  auto *attrPos = cast<AttributeLiteralPosition>(pos);
360  value = pdl_interp::CreateAttributeOp::create(builder, loc,
361  attrPos->getValue());
362  break;
363  }
365  auto *typePos = cast<TypeLiteralPosition>(pos);
366  Attribute rawTypeAttr = typePos->getValue();
367  if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368  value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
369  else
370  value = pdl_interp::CreateTypesOp::create(builder, loc,
371  cast<ArrayAttr>(rawTypeAttr));
372  break;
373  }
375  // Due to the order of traversal, the ApplyConstraintOp has already been
376  // created and we can find it in constraintOpMap.
377  auto *constrResPos = cast<ConstraintPosition>(pos);
378  auto i = constraintOpMap.find(constrResPos->getQuestion());
379  assert(i != constraintOpMap.end());
380  value = i->second->getResult(constrResPos->getIndex());
381  break;
382  }
383  default:
384  llvm_unreachable("Generating unknown Position getter");
385  break;
386  }
387 
388  values.insert(pos, value);
389  return value;
390 }
391 
392 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393  Value val) {
394  Location loc = val.getLoc();
395  Qualifier *question = boolNode->getQuestion();
396  Qualifier *answer = boolNode->getAnswer();
397  Region *region = currentBlock->getParent();
398 
399  // Execute the getValue queries first, so that we create success
400  // matcher in the correct (possibly nested) region.
401  SmallVector<Value> args;
402  if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403  args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404  } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405  for (Position *position : cstQuestion->getArgs())
406  args.push_back(getValueAt(currentBlock, position));
407  }
408 
409  // Generate a new block as success successor and get the failure successor.
410  Block *success = &region->emplaceBlock();
411  Block *failure = failureBlockStack.back();
412 
413  // Create the predicate.
414  builder.setInsertionPointToEnd(currentBlock);
415  Predicates::Kind kind = question->getKind();
416  switch (kind) {
418  pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure);
419  break;
421  auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422  pdl_interp::CheckOperationNameOp::create(
423  builder, loc, val, opNameAnswer->getValue().getStringRef(), success,
424  failure);
425  break;
426  }
428  auto *ans = cast<TypeAnswer>(answer);
429  if (isa<pdl::RangeType>(val.getType()))
430  pdl_interp::CheckTypesOp::create(builder, loc, val,
431  llvm::cast<ArrayAttr>(ans->getValue()),
432  success, failure);
433  else
434  pdl_interp::CheckTypeOp::create(builder, loc, val,
435  llvm::cast<TypeAttr>(ans->getValue()),
436  success, failure);
437  break;
438  }
440  auto *ans = cast<AttributeAnswer>(answer);
441  pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
442  success, failure);
443  break;
444  }
447  pdl_interp::CheckOperandCountOp::create(
448  builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
449  /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
450  success, failure);
451  break;
454  pdl_interp::CheckResultCountOp::create(
455  builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
456  /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
457  success, failure);
458  break;
460  bool trueAnswer = isa<TrueAnswer>(answer);
461  pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
462  trueAnswer ? success : failure,
463  trueAnswer ? failure : success);
464  break;
465  }
467  auto *cstQuestion = cast<ConstraintQuestion>(question);
468  auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
469  builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
470  args, cstQuestion->getIsNegated(), success, failure);
471 
472  constraintOpMap.insert({cstQuestion, applyConstraintOp});
473  break;
474  }
475  default:
476  llvm_unreachable("Generating unknown Predicate operation");
477  }
478 
479  // Generate the matcher in the current (potentially nested) region.
480  // This might use the results of the current predicate.
481  generateMatcher(*boolNode->getSuccessNode(), *region, success);
482 }
483 
484 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
485 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
486  llvm::MapVector<Qualifier *, Block *> &dests) {
487  std::vector<ValT> values;
488  std::vector<Block *> blocks;
489  values.reserve(dests.size());
490  blocks.reserve(dests.size());
491  for (const auto &it : dests) {
492  blocks.push_back(it.second);
493  values.push_back(cast<PredT>(it.first)->getValue());
494  }
495  OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks);
496 }
497 
498 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
499  Value val) {
500  Qualifier *question = switchNode->getQuestion();
501  Region *region = currentBlock->getParent();
502  Block *defaultDest = failureBlockStack.back();
503 
504  // If the switch question is not an exact answer, i.e. for the `at_least`
505  // cases, we generate a special block sequence.
506  Predicates::Kind kind = question->getKind();
509  // Order the children such that the cases are in reverse numerical order.
510  SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
511  llvm::seq<unsigned>(0, switchNode->getChildren().size()));
512  llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
513  return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
514  cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
515  });
516 
517  // Build the destination for each child using the next highest child as a
518  // a failure destination. This essentially creates the following control
519  // flow:
520  //
521  // if (operand_count < 1)
522  // goto failure
523  // if (child1.match())
524  // ...
525  //
526  // if (operand_count < 2)
527  // goto failure
528  // if (child2.match())
529  // ...
530  //
531  // failure:
532  // ...
533  //
534  failureBlockStack.push_back(defaultDest);
535  Location loc = val.getLoc();
536  for (unsigned idx : sortedChildren) {
537  auto &child = switchNode->getChild(idx);
538  Block *childBlock = generateMatcher(*child.second, *region);
539  Block *predicateBlock = builder.createBlock(childBlock);
540  builder.setInsertionPointToEnd(predicateBlock);
541  unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
542  switch (kind) {
544  pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
545  /*compareAtLeast=*/true,
546  childBlock, defaultDest);
547  break;
549  pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
550  /*compareAtLeast=*/true,
551  childBlock, defaultDest);
552  break;
553  default:
554  llvm_unreachable("Generating invalid AtLeast operation");
555  }
556  failureBlockStack.back() = predicateBlock;
557  }
558  Block *firstPredicateBlock = failureBlockStack.pop_back_val();
559  currentBlock->getOperations().splice(currentBlock->end(),
560  firstPredicateBlock->getOperations());
561  firstPredicateBlock->erase();
562  return;
563  }
564 
565  // Otherwise, generate each of the children and generate an interpreter
566  // switch.
567  llvm::MapVector<Qualifier *, Block *> children;
568  for (auto &it : switchNode->getChildren())
569  children.insert({it.first, generateMatcher(*it.second, *region)});
570  builder.setInsertionPointToEnd(currentBlock);
571 
572  switch (question->getKind()) {
574  return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
575  int32_t>(val, defaultDest, builder, children);
577  return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
578  int32_t>(val, defaultDest, builder, children);
580  return createSwitchOp<pdl_interp::SwitchOperationNameOp,
581  OperationNameAnswer>(val, defaultDest, builder,
582  children);
584  if (isa<pdl::RangeType>(val.getType())) {
585  return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
586  val, defaultDest, builder, children);
587  }
588  return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
589  val, defaultDest, builder, children);
591  return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
592  val, defaultDest, builder, children);
593  default:
594  llvm_unreachable("Generating unknown switch predicate.");
595  }
596 }
597 
598 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
599  pdl::PatternOp pattern = successNode->getPattern();
600  Value root = successNode->getRoot();
601 
602  // Generate a rewriter for the pattern this success node represents, and track
603  // any values used from the match region.
604  SmallVector<Position *, 8> usedMatchValues;
605  SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
606 
607  // Process any values used in the rewrite that are defined in the match.
608  std::vector<Value> mappedMatchValues;
609  mappedMatchValues.reserve(usedMatchValues.size());
610  for (Position *position : usedMatchValues)
611  mappedMatchValues.push_back(getValueAt(currentBlock, position));
612 
613  // Collect the set of operations generated by the rewriter.
614  SmallVector<StringRef, 4> generatedOps;
615  for (auto op :
616  pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
617  generatedOps.push_back(*op.getOpName());
618  ArrayAttr generatedOpsAttr;
619  if (!generatedOps.empty())
620  generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
621 
622  // Grab the root kind if present.
623  StringAttr rootKindAttr;
624  if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
625  if (std::optional<StringRef> rootKind = rootOp.getOpName())
626  rootKindAttr = builder.getStringAttr(*rootKind);
627 
628  builder.setInsertionPointToEnd(currentBlock);
629  auto matchOp = pdl_interp::RecordMatchOp::create(
630  builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
631  rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
632  failureBlockStack.back());
633 
634  // Set the config of the lowered match to the parent pattern.
635  if (configMap)
636  configMap->try_emplace(matchOp, configMap->lookup(pattern));
637 }
638 
639 SymbolRefAttr PatternLowering::generateRewriter(
640  pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
641  builder.setInsertionPointToEnd(rewriterModule.getBody());
642  auto rewriterFunc = pdl_interp::FuncOp::create(
643  builder, pattern.getLoc(), "pdl_generated_rewriter",
644  builder.getFunctionType({}, {}));
645  rewriterSymbolTable.insert(rewriterFunc);
646 
647  // Generate the rewriter function body.
648  builder.setInsertionPointToEnd(&rewriterFunc.front());
649 
650  // Map an input operand of the pattern to a generated interpreter value.
651  DenseMap<Value, Value> rewriteValues;
652  auto mapRewriteValue = [&](Value oldValue) {
653  Value &newValue = rewriteValues[oldValue];
654  if (newValue)
655  return newValue;
656 
657  // Prefer materializing constants directly when possible.
658  Operation *oldOp = oldValue.getDefiningOp();
659  if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
660  if (Attribute value = attrOp.getValueAttr()) {
661  return newValue = pdl_interp::CreateAttributeOp::create(
662  builder, attrOp.getLoc(), value);
663  }
664  } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
665  if (TypeAttr type = typeOp.getConstantTypeAttr()) {
666  return newValue = pdl_interp::CreateTypeOp::create(
667  builder, typeOp.getLoc(), type);
668  }
669  } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
670  if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
671  return newValue = pdl_interp::CreateTypesOp::create(
672  builder, typeOp.getLoc(), typeOp.getType(), type);
673  }
674  }
675 
676  // Otherwise, add this as an input to the rewriter.
677  Position *inputPos = valueToPosition.lookup(oldValue);
678  assert(inputPos && "expected value to be a pattern input");
679  usedMatchValues.push_back(inputPos);
680  return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
681  oldValue.getLoc());
682  };
683 
684  // If this is a custom rewriter, simply dispatch to the registered rewrite
685  // method.
686  pdl::RewriteOp rewriter = pattern.getRewriter();
687  if (StringAttr rewriteName = rewriter.getNameAttr()) {
688  SmallVector<Value> args;
689  if (rewriter.getRoot())
690  args.push_back(mapRewriteValue(rewriter.getRoot()));
691  auto mappedArgs =
692  llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
693  args.append(mappedArgs.begin(), mappedArgs.end());
694  pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
695  /*resultTypes=*/TypeRange(), rewriteName,
696  args);
697  } else {
698  // Otherwise this is a dag rewriter defined using PDL operations.
699  for (Operation &rewriteOp : *rewriter.getBody()) {
701  .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
702  pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
703  pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
704  this->generateRewriter(op, rewriteValues, mapRewriteValue);
705  });
706  }
707  }
708 
709  // Update the signature of the rewrite function.
710  rewriterFunc.setType(builder.getFunctionType(
711  llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
712  /*results=*/{}));
713 
714  pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
715  return SymbolRefAttr::get(
716  builder.getContext(),
717  pdl_interp::PDLInterpDialect::getRewriterModuleName(),
718  SymbolRefAttr::get(rewriterFunc));
719 }
720 
721 void PatternLowering::generateRewriter(
722  pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
723  function_ref<Value(Value)> mapRewriteValue) {
724  SmallVector<Value, 2> arguments;
725  for (Value argument : rewriteOp.getArgs())
726  arguments.push_back(mapRewriteValue(argument));
727  auto interpOp = pdl_interp::ApplyRewriteOp::create(
728  builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
729  rewriteOp.getNameAttr(), arguments);
730  for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
731  rewriteValues[std::get<0>(it)] = std::get<1>(it);
732 }
733 
734 void PatternLowering::generateRewriter(
735  pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
736  function_ref<Value(Value)> mapRewriteValue) {
737  Value newAttr = pdl_interp::CreateAttributeOp::create(
738  builder, attrOp.getLoc(), attrOp.getValueAttr());
739  rewriteValues[attrOp] = newAttr;
740 }
741 
742 void PatternLowering::generateRewriter(
743  pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
744  function_ref<Value(Value)> mapRewriteValue) {
745  pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
746  mapRewriteValue(eraseOp.getOpValue()));
747 }
748 
749 void PatternLowering::generateRewriter(
750  pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
751  function_ref<Value(Value)> mapRewriteValue) {
752  SmallVector<Value, 4> operands;
753  for (Value operand : operationOp.getOperandValues())
754  operands.push_back(mapRewriteValue(operand));
755 
756  SmallVector<Value, 4> attributes;
757  for (Value attr : operationOp.getAttributeValues())
758  attributes.push_back(mapRewriteValue(attr));
759 
760  bool hasInferredResultTypes = false;
761  SmallVector<Value, 2> types;
762  generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
763  rewriteValues, hasInferredResultTypes);
764 
765  // Create the new operation.
766  Location loc = operationOp.getLoc();
767  Value createdOp = pdl_interp::CreateOperationOp::create(
768  builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
769  operands, attributes, operationOp.getAttributeValueNames());
770  rewriteValues[operationOp.getOp()] = createdOp;
771 
772  // Generate accesses for any results that have their types constrained.
773  // Handle the case where there is a single range representing all of the
774  // result types.
775  OperandRange resultTys = operationOp.getTypeValues();
776  if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
777  Value &type = rewriteValues[resultTys[0]];
778  if (!type) {
779  auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
780  type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
781  }
782  return;
783  }
784 
785  // Otherwise, populate the individual results.
786  bool seenVariableLength = false;
787  Type valueTy = builder.getType<pdl::ValueType>();
788  Type valueRangeTy = pdl::RangeType::get(valueTy);
789  for (const auto &it : llvm::enumerate(resultTys)) {
790  Value &type = rewriteValues[it.value()];
791  if (type)
792  continue;
793  bool isVariadic = isa<pdl::RangeType>(it.value().getType());
794  seenVariableLength |= isVariadic;
795 
796  // After a variable length result has been seen, we need to use result
797  // groups because the exact index of the result is not statically known.
798  Value resultVal;
799  if (seenVariableLength)
800  resultVal = pdl_interp::GetResultsOp::create(
801  builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
802  it.index());
803  else
804  resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
805  createdOp, it.index());
806  type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
807  }
808 }
809 
810 void PatternLowering::generateRewriter(
811  pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
812  function_ref<Value(Value)> mapRewriteValue) {
813  SmallVector<Value, 4> replOperands;
814  for (Value operand : rangeOp.getArguments())
815  replOperands.push_back(mapRewriteValue(operand));
816  rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
817  builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
818 }
819 
820 void PatternLowering::generateRewriter(
821  pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
822  function_ref<Value(Value)> mapRewriteValue) {
823  SmallVector<Value, 4> replOperands;
824 
825  // If the replacement was another operation, get its results. `pdl` allows
826  // for using an operation for simplicitly, but the interpreter isn't as
827  // user facing.
828  if (Value replOp = replaceOp.getReplOperation()) {
829  // Don't use replace if we know the replaced operation has no results.
830  auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
831  if (!opOp || !opOp.getTypeValues().empty()) {
832  replOperands.push_back(pdl_interp::GetResultsOp::create(
833  builder, replOp.getLoc(), mapRewriteValue(replOp)));
834  }
835  } else {
836  for (Value operand : replaceOp.getReplValues())
837  replOperands.push_back(mapRewriteValue(operand));
838  }
839 
840  // If there are no replacement values, just create an erase instead.
841  if (replOperands.empty()) {
842  pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
843  mapRewriteValue(replaceOp.getOpValue()));
844  return;
845  }
846 
847  pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
848  mapRewriteValue(replaceOp.getOpValue()),
849  replOperands);
850 }
851 
852 void PatternLowering::generateRewriter(
853  pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
854  function_ref<Value(Value)> mapRewriteValue) {
855  rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
856  builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
857  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
858 }
859 
860 void PatternLowering::generateRewriter(
861  pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
862  function_ref<Value(Value)> mapRewriteValue) {
863  rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
864  builder, resultOp.getLoc(), resultOp.getType(),
865  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
866 }
867 
868 void PatternLowering::generateRewriter(
869  pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
870  function_ref<Value(Value)> mapRewriteValue) {
871  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
872  // type.
873  if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
874  rewriteValues[typeOp] =
875  pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
876  }
877 }
878 
879 void PatternLowering::generateRewriter(
880  pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
881  function_ref<Value(Value)> mapRewriteValue) {
882  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
883  // type.
884  if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
885  rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
886  builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
887  }
888 }
889 
890 void PatternLowering::generateOperationResultTypeRewriter(
891  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
892  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
893  bool &hasInferredResultTypes) {
894  Block *rewriterBlock = op->getBlock();
895 
896  // Try to handle resolution for each of the result types individually. This is
897  // preferred over type inferrence because it will allow for us to use existing
898  // types directly, as opposed to trying to rebuild the type list.
899  OperandRange resultTypeValues = op.getTypeValues();
900  auto tryResolveResultTypes = [&] {
901  types.reserve(resultTypeValues.size());
902  for (const auto &it : llvm::enumerate(resultTypeValues)) {
903  Value resultType = it.value();
904 
905  // Check for an already translated value.
906  if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
907  types.push_back(existingRewriteValue);
908  continue;
909  }
910 
911  // Check for an input from the matcher.
912  if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
913  types.push_back(mapRewriteValue(resultType));
914  continue;
915  }
916 
917  // Otherwise, we couldn't infer the result types. Bail out here to see if
918  // we can infer the types for this operation from another way.
919  types.clear();
920  return failure();
921  }
922  return success();
923  };
924  if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
925  return;
926 
927  // Otherwise, check if the operation has type inference support itself.
928  if (op.hasTypeInference()) {
929  hasInferredResultTypes = true;
930  return;
931  }
932 
933  // Look for an operation that was replaced by `op`. The result types will be
934  // inferred from the results that were replaced.
935  for (OpOperand &use : op.getOp().getUses()) {
936  // Check that the use corresponds to a ReplaceOp and that it is the
937  // replacement value, not the operation being replaced.
938  pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
939  if (!replOpUser || use.getOperandNumber() == 0)
940  continue;
941  // Make sure the replaced operation was defined before this one. PDL
942  // rewrites only have single block regions, so if the op isn't in the
943  // rewriter block (i.e. the current block of the operation) we already know
944  // it dominates (i.e. it's in the matcher).
945  Value replOpVal = replOpUser.getOpValue();
946  Operation *replacedOp = replOpVal.getDefiningOp();
947  if (replacedOp->getBlock() == rewriterBlock &&
948  !replacedOp->isBeforeInBlock(op))
949  continue;
950 
951  Value replacedOpResults = pdl_interp::GetResultsOp::create(
952  builder, replacedOp->getLoc(), mapRewriteValue(replOpVal));
953  types.push_back(pdl_interp::GetValueTypeOp::create(
954  builder, replacedOp->getLoc(), replacedOpResults));
955  return;
956  }
957 
958  // If the types could not be inferred from any context and there weren't any
959  // explicit result types, assume the user actually meant for the operation to
960  // have no results.
961  if (resultTypeValues.empty())
962  return;
963 
964  // The verifier asserts that the result types of each pdl.getOperation can be
965  // inferred. If we reach here, there is a bug either in the logic above or
966  // in the verifier for pdl.getOperation.
967  op->emitOpError() << "unable to infer result type for operation";
968  llvm_unreachable("unable to infer result type for operation");
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // Conversion Pass
973 //===----------------------------------------------------------------------===//
974 
975 namespace {
976 struct PDLToPDLInterpPass
977  : public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
978  PDLToPDLInterpPass() = default;
979  PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
980  PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
981  : configMap(&configMap) {}
982  void runOnOperation() final;
983 
984  /// A map containing the configuration for each pattern.
985  DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
986 };
987 } // namespace
988 
989 /// Convert the given module containing PDL pattern operations into a PDL
990 /// Interpreter operations.
991 void PDLToPDLInterpPass::runOnOperation() {
992  ModuleOp module = getOperation();
993 
994  // Create the main matcher function This function contains all of the match
995  // related functionality from patterns in the module.
996  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
997  auto matcherFunc = pdl_interp::FuncOp::create(
998  builder, module.getLoc(),
999  pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
1000  builder.getFunctionType(builder.getType<pdl::OperationType>(),
1001  /*results=*/{}),
1002  /*attrs=*/ArrayRef<NamedAttribute>());
1003 
1004  // Create a nested module to hold the functions invoked for rewriting the IR
1005  // after a successful match.
1006  ModuleOp rewriterModule =
1007  ModuleOp::create(builder, module.getLoc(),
1008  pdl_interp::PDLInterpDialect::getRewriterModuleName());
1009 
1010  // Generate the code for the patterns within the module.
1011  PatternLowering generator(matcherFunc, rewriterModule, configMap);
1012  generator.lower(module);
1013 
1014  // After generation, delete all of the pattern operations.
1015  for (pdl::PatternOp pattern :
1016  llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1017  // Drop the now dead config mappings.
1018  if (configMap)
1019  configMap->erase(pattern);
1020 
1021  pattern.erase();
1022  }
1023 }
1024 
1025 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertPDLToPDLInterpPass(
1027  return std::make_unique<PDLToPDLInterpPass>(configMap);
1028 }
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1244::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
llvm::ScopedHashTable< mlir::Value, std::string > ValueMap
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
This class represents an operand of an operation.
Definition: Value.h:257
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & emplaceBlock()
Definition: Region.h:46
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Definition: PredicateTree.h:70
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition: Predicate.h:153
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
This class provides utilities for constructing predicates.
Definition: Predicate.h:610
This class provides a storage uniquer that is used to allocate predicate instances.
Definition: Predicate.h:566
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:422
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:427
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition: Predicate.h:46
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createConvertPDLToPDLInterpPass(DenseMap< Operation *, PDLPatternConfigSet * > &configMap)
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
A BoolNode denotes a question with a boolean-like result.
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
An Answer representing an OperationName value.
Definition: Predicate.h:448
A SuccessNode denotes that a given high level pattern has successfully been matched.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
A SwitchNode denotes a question with multiple potential results.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.