MLIR  20.0.0git
PDLToPDLInterp.cpp
Go to the documentation of this file.
1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "PredicateTree.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::pdl_to_pdl_interp;
30 
31 //===----------------------------------------------------------------------===//
32 // PatternLowering
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// This class generates operations within the PDL Interpreter dialect from a
37 /// given module containing PDL pattern operations.
38 struct PatternLowering {
39 public:
 // NOTE(review): the tail of this constructor declaration is not visible in
 // this listing. The out-of-line definition initializes the builder, the
 // matcher function, the rewriter module/symbol table and `configMap`.
40  PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
42 
43  /// Generate code for matching and rewriting based on the pattern operations
44  /// within the module.
45  void lower(ModuleOp module);
46 
47 private:
 // The value map is scoped: each matcher node pushes its own scope, so values
 // inserted while generating a node are dropped when that node is finished.
48  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
50 
51  /// Generate interpreter operations for the tree rooted at the given matcher
52  /// node, in the specified region. If `block` is null, a new block is
53  /// appended to the region. Returns the block the node was generated into.
53  Block *generateMatcher(MatcherNode &node, Region &region,
54  Block *block = nullptr);
55 
56  /// Get or create an access to the provided positional value in the current
57  /// block. This operation may mutate the provided block pointer if nested
58  /// regions (i.e., pdl_interp.iterate) are required.
59  Value getValueAt(Block *&currentBlock, Position *pos);
60 
61  /// Create the interpreter predicate operations. This operation may mutate the
62  /// provided current block pointer if nested regions (iterates) are required.
63  void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
64 
65  /// Create the interpreter switch / predicate operations, with several case
66  /// destinations. This operation never mutates the provided current block
67  /// pointer, because the switch operation does not need Values beyond `val`.
68  void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
69 
70  /// Create the interpreter operations to record a successful pattern match
71  /// using the contained root operation. This operation may mutate the current
72  /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
73  void generate(SuccessNode *successNode, Block *&currentBlock);
74 
75  /// Generate a rewriter function for the given pattern operation, and returns
76  /// a reference to that function. The positions of the match values that the
77  /// rewriter needs as inputs are appended to `usedMatchValues`.
77  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
78  SmallVectorImpl<Position *> &usedMatchValues);
79 
80  /// Generate the rewriter code for the given operation.
 /// In every overload below, `rewriteValues` maps PDL values to the
 /// interpreter values already generated for them, and `mapRewriteValue`
 /// returns the interpreter value for a PDL value, creating it when needed.
81  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
82  DenseMap<Value, Value> &rewriteValues,
83  function_ref<Value(Value)> mapRewriteValue);
84  void generateRewriter(pdl::AttributeOp attrOp,
85  DenseMap<Value, Value> &rewriteValues,
86  function_ref<Value(Value)> mapRewriteValue);
87  void generateRewriter(pdl::EraseOp eraseOp,
88  DenseMap<Value, Value> &rewriteValues,
89  function_ref<Value(Value)> mapRewriteValue);
90  void generateRewriter(pdl::OperationOp operationOp,
91  DenseMap<Value, Value> &rewriteValues,
92  function_ref<Value(Value)> mapRewriteValue);
93  void generateRewriter(pdl::RangeOp rangeOp,
94  DenseMap<Value, Value> &rewriteValues,
95  function_ref<Value(Value)> mapRewriteValue);
96  void generateRewriter(pdl::ReplaceOp replaceOp,
97  DenseMap<Value, Value> &rewriteValues,
98  function_ref<Value(Value)> mapRewriteValue);
99  void generateRewriter(pdl::ResultOp resultOp,
100  DenseMap<Value, Value> &rewriteValues,
101  function_ref<Value(Value)> mapRewriteValue);
102  void generateRewriter(pdl::ResultsOp resultOp,
103  DenseMap<Value, Value> &rewriteValues,
104  function_ref<Value(Value)> mapRewriteValue);
105  void generateRewriter(pdl::TypeOp typeOp,
106  DenseMap<Value, Value> &rewriteValues,
107  function_ref<Value(Value)> mapRewriteValue);
108  void generateRewriter(pdl::TypesOp typeOp,
109  DenseMap<Value, Value> &rewriteValues,
110  function_ref<Value(Value)> mapRewriteValue);
111 
112  /// Generate the values used for resolving the result types of an operation
113  /// created within a dag rewriter region. If the result types of the operation
114  /// should be inferred, `hasInferredResultTypes` is set to true.
115  void generateOperationResultTypeRewriter(
116  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
117  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
118  bool &hasInferredResultTypes);
119 
120  /// A builder to use when generating interpreter operations.
121  OpBuilder builder;
122 
123  /// The matcher function used for all match related logic within PDL patterns.
124  pdl_interp::FuncOp matcherFunc;
125 
126  /// The rewriter module containing all of the rewrite related logic within
127  /// PDL patterns.
128  ModuleOp rewriterModule;
129 
130  /// The symbol table of the rewriter module used for insertion.
131  SymbolTable rewriterSymbolTable;
132 
133  /// A scoped map connecting a position with the corresponding interpreter
134  /// value.
135  ValueMap values;
136 
137  /// A stack of blocks used as the failure destination for matcher nodes that
138  /// don't have an explicit failure path.
139  SmallVector<Block *, 8> failureBlockStack;
140 
141  /// A mapping between values defined in a pattern match, and the corresponding
142  /// positional value.
143  DenseMap<Value, Position *> valueToPosition;
144 
145  /// The set of operation values whose location will be used for newly
146  /// generated operations.
147  SetVector<Value> locOps;
148 
149  /// A mapping between pattern operations and the corresponding configuration
150  /// set.
 // NOTE(review): the declaration documented above is not visible in this
 // listing; it is presumably the `configMap` member that the constructor
 // initializes and that is null-checked before use.
152 
153  /// A mapping from a constraint question to the ApplyConstraintOp
154  /// that implements it.
 // NOTE(review): the declaration documented above is not visible in this
 // listing; it is presumably the `constraintOpMap` member that the matcher
 // code inserts into and looks up.
156 };
157 } // namespace
158 
/// Construct a lowering that emits matcher code into `matcherFunc` and
/// rewriter functions into `rewriterModule`. The builder is created from the
/// matcher function's context and the rewriter symbol table wraps
/// `rewriterModule`.
// NOTE(review): the parameter line that declares `configMap` is not visible in
// this listing.
159 PatternLowering::PatternLowering(
160  pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
162  : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
163  rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
164  configMap(configMap) {}
165 
166 void PatternLowering::lower(ModuleOp module) {
167  PredicateUniquer predicateUniquer;
168  PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
169 
170  // Define top-level scope for the arguments to the matcher function.
171  ValueMapScope topLevelValueScope(values);
172 
173  // Insert the root operation, i.e. argument to the matcher, at the root
174  // position.
175  Block *matcherEntryBlock = &matcherFunc.front();
176  values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
177 
178  // Generate a root matcher node from the provided PDL module.
179  std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
180  module, predicateBuilder, valueToPosition);
181  Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
182  assert(failureBlockStack.empty() && "failed to empty the stack");
183 
184  // After generation, merged the first matched block into the entry.
185  matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
186  firstMatcherBlock->getOperations());
187  firstMatcherBlock->erase();
188 }
189 
/// Emit interpreter code for the matcher tree rooted at `node` into `region`.
/// When `block` is null a new block is appended to `region`; the block the
/// node was generated into is returned. The failure successor of `node` is
/// generated first, and `failureBlockStack` is restored to its entry state
/// before returning.
190 Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
191  Block *block) {
192  // Create a block if the caller did not supply one, then push a new scope
 // for the values used by this matcher.
193  if (!block)
194  block = &region.emplaceBlock();
195  ValueMapScope scope(values);
196 
197  // If this is the return node, simply insert the corresponding interpreter
198  // finalize.
199  if (isa<ExitNode>(node)) {
200  builder.setInsertionPointToEnd(block);
201  builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
202  return block;
203  }
204 
205  // Get the next block in the match sequence.
206  // This is intentionally executed first, before we get the value for the
207  // position associated with the node, so that we preserve a "there exists"
208  // semantics: if getting a value requires an upward traversal (going from a
209  // value to its consumers), we want to perform the check on all the consumers
210  // before we pass control to the failure node.
211  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
212  Block *failureBlock;
213  if (failureNode) {
214  failureBlock = generateMatcher(*failureNode, region);
215  failureBlockStack.push_back(failureBlock);
216  } else {
 // No explicit failure path: inherit the innermost enclosing one.
217  assert(!failureBlockStack.empty() && "expected valid failure block");
218  failureBlock = failureBlockStack.back();
219  }
220 
221  // If this node contains a position, get the corresponding value for this
222  // block.
 // `getValueAt` may redirect `currentBlock` into a nested region; `block`
 // keeps pointing at the entry block, which is what gets returned.
223  Block *currentBlock = block;
224  Position *position = node.getPosition();
225  Value val = position ? getValueAt(currentBlock, position) : Value();
226 
227  // If this value corresponds to an operation, record that we are going to use
228  // its location as part of a fused location.
229  bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
230  if (isOperationValue)
231  locOps.insert(val);
232 
233  // Dispatch to the correct method based on derived node type.
 // NOTE(review): the line that opens this dispatch expression is not visible
 // in this listing; only its `.Case` clauses are.
235  .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
236  this->generate(derivedNode, currentBlock, val);
237  })
238  .Case([&](SuccessNode *successNode) {
239  generate(successNode, currentBlock);
240  });
241 
242  // Pop all the failure blocks that were inserted due to nesting of
243  // pdl_interp.iterate.
244  while (failureBlockStack.back() != failureBlock) {
245  failureBlockStack.pop_back();
246  assert(!failureBlockStack.empty() && "unable to locate failure block");
247  }
248 
249  // Pop the new failure block.
250  if (failureNode)
251  failureBlockStack.pop_back();
252 
 // Undo the location bookkeeping performed above for this node's value.
253  if (isOperationValue)
254  locOps.remove(val);
255 
256  return block;
257 }
258 
/// Return the interpreter value for `pos`, creating it on demand.
///
/// A value already recorded in `values` is returned directly. Otherwise the
/// parent position (if any) is resolved recursively, the access operation for
/// `pos` is appended to `currentBlock`, and the result is cached in `values`.
/// For a `ForEachPos`, `currentBlock` is redirected to the first block of the
/// new loop's region and a continuation block is pushed onto
/// `failureBlockStack`.
259 Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
260  if (Value val = values.lookup(pos))
261  return val;
262 
263  // Get the value for the parent position.
264  Value parentVal;
265  if (Position *parent = pos->getParent())
266  parentVal = getValueAt(currentBlock, parent);
267 
268  // TODO: Use a location from the position.
269  Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
270  builder.setInsertionPointToEnd(currentBlock);
271  Value value;
 // NOTE(review): several `case` labels of this switch are not visible in this
 // listing. The cast at the top of each unlabeled section shows which
 // position kind that section handles.
272  switch (pos->getKind()) {
274  auto *operationPos = cast<OperationPosition>(pos);
275  if (operationPos->isOperandDefiningOp())
276  // Standard (downward) traversal which directly follows the defining op.
277  value = builder.create<pdl_interp::GetDefiningOpOp>(
278  loc, builder.getType<pdl::OperationType>(), parentVal);
279  else
280  // A passthrough operation position.
281  value = parentVal;
282  break;
283  }
284  case Predicates::UsersPos: {
285  auto *usersPos = cast<UsersPosition>(pos);
286 
287  // The first operation retrieves the representative value of a range.
288  // This applies only when the parent is a range of values and we were
289  // requested to use a representative value (e.g., upward traversal).
290  if (isa<pdl::RangeType>(parentVal.getType()) &&
291  usersPos->useRepresentative())
292  value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
293  else
294  value = parentVal;
295 
296  // The second operation retrieves the users.
297  value = builder.create<pdl_interp::GetUsersOp>(loc, value);
298  break;
299  }
300  case Predicates::ForEachPos: {
301  assert(!failureBlockStack.empty() && "expected valid failure block");
 // The current innermost failure block is passed as the loop's successor.
302  auto foreach = builder.create<pdl_interp::ForEachOp>(
303  loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
304  value = foreach.getLoopVariable();
305 
306  // Create the continuation block.
 // It becomes the failure destination for code generated inside the loop.
307  Block *continueBlock = builder.createBlock(&foreach.getRegion());
308  builder.create<pdl_interp::ContinueOp>(loc);
309  failureBlockStack.push_back(continueBlock);
310 
 // Subsequent code is generated inside the loop body.
311  currentBlock = &foreach.getRegion().front();
312  break;
313  }
314  case Predicates::OperandPos: {
315  auto *operandPos = cast<OperandPosition>(pos);
316  value = builder.create<pdl_interp::GetOperandOp>(
317  loc, builder.getType<pdl::ValueType>(), parentVal,
318  operandPos->getOperandNumber());
319  break;
320  }
322  auto *operandPos = cast<OperandGroupPosition>(pos);
 // A variadic group yields a range of values, otherwise a single value.
323  Type valueTy = builder.getType<pdl::ValueType>();
324  value = builder.create<pdl_interp::GetOperandsOp>(
325  loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
326  parentVal, operandPos->getOperandGroupNumber());
327  break;
328  }
330  auto *attrPos = cast<AttributePosition>(pos);
331  value = builder.create<pdl_interp::GetAttributeOp>(
332  loc, builder.getType<pdl::AttributeType>(), parentVal,
333  attrPos->getName().strref());
334  break;
335  }
336  case Predicates::TypePos: {
 // The parent is either an attribute or a value; pick the matching getter.
337  if (isa<pdl::AttributeType>(parentVal.getType()))
338  value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
339  else
340  value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
341  break;
342  }
343  case Predicates::ResultPos: {
344  auto *resPos = cast<ResultPosition>(pos);
345  value = builder.create<pdl_interp::GetResultOp>(
346  loc, builder.getType<pdl::ValueType>(), parentVal,
347  resPos->getResultNumber());
348  break;
349  }
351  auto *resPos = cast<ResultGroupPosition>(pos);
 // A variadic group yields a range of values, otherwise a single value.
352  Type valueTy = builder.getType<pdl::ValueType>();
353  value = builder.create<pdl_interp::GetResultsOp>(
354  loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355  parentVal, resPos->getResultGroupNumber());
356  break;
357  }
359  auto *attrPos = cast<AttributeLiteralPosition>(pos);
360  value =
361  builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
362  break;
363  }
365  auto *typePos = cast<TypeLiteralPosition>(pos);
 // The literal is either a single TypeAttr or an ArrayAttr of types.
366  Attribute rawTypeAttr = typePos->getValue();
367  if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368  value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
369  else
370  value = builder.create<pdl_interp::CreateTypesOp>(
371  loc, cast<ArrayAttr>(rawTypeAttr));
372  break;
373  }
375  // Due to the order of traversal, the ApplyConstraintOp has already been
376  // created and we can find it in constraintOpMap.
377  auto *constrResPos = cast<ConstraintPosition>(pos);
378  auto i = constraintOpMap.find(constrResPos->getQuestion());
379  assert(i != constraintOpMap.end());
380  value = i->second->getResult(constrResPos->getIndex());
381  break;
382  }
383  default:
384  llvm_unreachable("Generating unknown Position getter");
385  break;
386  }
387 
 // Cache the value in the current scope so later lookups reuse it.
388  values.insert(pos, value);
389  return value;
390 }
391 
/// Emit the interpreter predicate for `boolNode` on `val` at the end of
/// `currentBlock`, then generate the node's success subtree into a new block.
/// The failure successor is the innermost entry of `failureBlockStack`.
/// `currentBlock` may be redirected by the `getValueAt` calls that resolve the
/// predicate's extra arguments.
392 void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393  Value val) {
394  Location loc = val.getLoc();
395  Qualifier *question = boolNode->getQuestion();
396  Qualifier *answer = boolNode->getAnswer();
397  Region *region = currentBlock->getParent();
398 
399  // Execute the getValue queries first, so that we create success
400  // matcher in the correct (possibly nested) region.
401  SmallVector<Value> args;
402  if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403  args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404  } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405  for (Position *position : cstQuestion->getArgs())
406  args.push_back(getValueAt(currentBlock, position));
407  }
408 
409  // Generate a new block as success successor and get the failure successor.
 // NOTE(review): `region` was captured before the `getValueAt` calls above,
 // so the success block is appended to the region `currentBlock` belonged to
 // on entry — confirm this is intended when those calls nest into a loop.
410  Block *success = &region->emplaceBlock();
411  Block *failure = failureBlockStack.back();
412 
413  // Create the predicate.
 // NOTE(review): most `case` labels of this switch are not visible in this
 // listing; each section is identified by the operation it creates.
414  builder.setInsertionPointToEnd(currentBlock);
415  Predicates::Kind kind = question->getKind();
416  switch (kind) {
418  builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
419  break;
421  auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422  builder.create<pdl_interp::CheckOperationNameOp>(
423  loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
424  break;
425  }
427  auto *ans = cast<TypeAnswer>(answer);
 // A range value is checked against an array of types, a single value
 // against one type.
428  if (isa<pdl::RangeType>(val.getType()))
429  builder.create<pdl_interp::CheckTypesOp>(
430  loc, val, llvm::cast<ArrayAttr>(ans->getValue()), success, failure);
431  else
432  builder.create<pdl_interp::CheckTypeOp>(
433  loc, val, llvm::cast<TypeAttr>(ans->getValue()), success, failure);
434  break;
435  }
437  auto *ans = cast<AttributeAnswer>(answer);
438  builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
439  success, failure);
440  break;
441  }
 // Operand count check; `compareAtLeast` is set for the at-least question.
444  builder.create<pdl_interp::CheckOperandCountOp>(
445  loc, val, cast<UnsignedAnswer>(answer)->getValue(),
446  /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
447  success, failure);
448  break;
 // Result count check; `compareAtLeast` is set for the at-least question.
451  builder.create<pdl_interp::CheckResultCountOp>(
452  loc, val, cast<UnsignedAnswer>(answer)->getValue(),
453  /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
454  success, failure);
455  break;
 // Equality check: a non-true answer swaps the success and failure
 // successors.
457  bool trueAnswer = isa<TrueAnswer>(answer);
458  builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
459  trueAnswer ? success : failure,
460  trueAnswer ? failure : success);
461  break;
462  }
464  auto *cstQuestion = cast<ConstraintQuestion>(question);
465  auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
466  loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
467  cstQuestion->getIsNegated(), success, failure);
468 
 // Remember the op so that constraint-result positions can find its
 // results later.
469  constraintOpMap.insert({cstQuestion, applyConstraintOp});
470  break;
471  }
472  default:
473  llvm_unreachable("Generating unknown Predicate operation");
474  }
475 
476  // Generate the matcher in the current (potentially nested) region.
477  // This might use the results of the current predicate.
478  generateMatcher(*boolNode->getSuccessNode(), *region, success);
479 }
480 
481 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
482 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
483  llvm::MapVector<Qualifier *, Block *> &dests) {
484  std::vector<ValT> values;
485  std::vector<Block *> blocks;
486  values.reserve(dests.size());
487  blocks.reserve(dests.size());
488  for (const auto &it : dests) {
489  blocks.push_back(it.second);
490  values.push_back(cast<PredT>(it.first)->getValue());
491  }
492  builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
493 }
494 
/// Emit interpreter code for `switchNode` on `val`.
///
/// For the "at least" count questions a chain of count checks is generated,
/// one per child. Every other question generates all children and then a
/// single interpreter switch operation at the end of `currentBlock`. The
/// default destination is the innermost entry of `failureBlockStack`.
495 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
496  Value val) {
497  Qualifier *question = switchNode->getQuestion();
498  Region *region = currentBlock->getParent();
499  Block *defaultDest = failureBlockStack.back();
500 
501  // If the switch question is not an exact answer, i.e. for the `at_least`
502  // cases, we generate a special block sequence.
 // NOTE(review): the `if` that opens this special case is not visible in this
 // listing; its body runs up to the `return;` below.
503  Predicates::Kind kind = question->getKind();
506  // Order the children such that the cases are in reverse numerical order.
507  SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
508  llvm::seq<unsigned>(0, switchNode->getChildren().size()));
509  llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
510  return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
511  cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
512  });
513 
514  // Build the destination for each child using the next highest child as a
515  // failure destination. This essentially creates the following control
516  // flow:
517  //
518  // if (operand_count < 1)
519  // goto failure
520  // if (child1.match())
521  // ...
522  //
523  // if (operand_count < 2)
524  // goto failure
525  // if (child2.match())
526  // ...
527  //
528  // failure:
529  // ...
530  //
 // The top of the stack tracks the most recently built predicate block,
 // which serves as the failure destination while the next child is
 // generated.
531  failureBlockStack.push_back(defaultDest);
532  Location loc = val.getLoc();
533  for (unsigned idx : sortedChildren) {
534  auto &child = switchNode->getChild(idx);
535  Block *childBlock = generateMatcher(*child.second, *region);
 // The predicate block is created in front of the child's block.
536  Block *predicateBlock = builder.createBlock(childBlock);
537  builder.setInsertionPointToEnd(predicateBlock);
538  unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
 // NOTE(review): the `case` labels of this inner switch are not visible in
 // this listing.
539  switch (kind) {
541  builder.create<pdl_interp::CheckOperandCountOp>(
542  loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
543  break;
545  builder.create<pdl_interp::CheckResultCountOp>(
546  loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
547  break;
548  default:
549  llvm_unreachable("Generating invalid AtLeast operation");
550  }
551  failureBlockStack.back() = predicateBlock;
552  }
 // The last predicate block built is the entry of the chain; merge its
 // operations into the current block and erase the emptied block.
553  Block *firstPredicateBlock = failureBlockStack.pop_back_val();
554  currentBlock->getOperations().splice(currentBlock->end(),
555  firstPredicateBlock->getOperations());
556  firstPredicateBlock->erase();
557  return;
558  }
559 
560  // Otherwise, generate each of the children and generate an interpreter
561  // switch.
562  llvm::MapVector<Qualifier *, Block *> children;
563  for (auto &it : switchNode->getChildren())
564  children.insert({it.first, generateMatcher(*it.second, *region)});
565  builder.setInsertionPointToEnd(currentBlock);
566 
 // NOTE(review): the `case` labels of this switch are not visible in this
 // listing; each section is identified by the switch operation it creates.
567  switch (question->getKind()) {
569  return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
570  int32_t>(val, defaultDest, builder, children);
572  return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
573  int32_t>(val, defaultDest, builder, children);
575  return createSwitchOp<pdl_interp::SwitchOperationNameOp,
576  OperationNameAnswer>(val, defaultDest, builder,
577  children);
 // A range value switches on type arrays, a single value on one type.
579  if (isa<pdl::RangeType>(val.getType())) {
580  return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
581  val, defaultDest, builder, children);
582  }
583  return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
584  val, defaultDest, builder, children);
586  return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
587  val, defaultDest, builder, children);
588  default:
589  llvm_unreachable("Generating unknown switch predicate.");
590  }
591 }
592 
593 void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
594  pdl::PatternOp pattern = successNode->getPattern();
595  Value root = successNode->getRoot();
596 
597  // Generate a rewriter for the pattern this success node represents, and track
598  // any values used from the match region.
599  SmallVector<Position *, 8> usedMatchValues;
600  SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
601 
602  // Process any values used in the rewrite that are defined in the match.
603  std::vector<Value> mappedMatchValues;
604  mappedMatchValues.reserve(usedMatchValues.size());
605  for (Position *position : usedMatchValues)
606  mappedMatchValues.push_back(getValueAt(currentBlock, position));
607 
608  // Collect the set of operations generated by the rewriter.
609  SmallVector<StringRef, 4> generatedOps;
610  for (auto op :
611  pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
612  generatedOps.push_back(*op.getOpName());
613  ArrayAttr generatedOpsAttr;
614  if (!generatedOps.empty())
615  generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
616 
617  // Grab the root kind if present.
618  StringAttr rootKindAttr;
619  if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
620  if (std::optional<StringRef> rootKind = rootOp.getOpName())
621  rootKindAttr = builder.getStringAttr(*rootKind);
622 
623  builder.setInsertionPointToEnd(currentBlock);
624  auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
625  pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
626  rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
627  failureBlockStack.back());
628 
629  // Set the config of the lowered match to the parent pattern.
630  if (configMap)
631  configMap->try_emplace(matchOp, configMap->lookup(pattern));
632 }
633 
/// Create the rewriter function for `pattern` inside the rewriter module and
/// return a symbol reference to it, nested under the rewriter module name.
///
/// Match values that the rewrite needs become arguments of the new function;
/// their positions are appended to `usedMatchValues` in argument order.
634 SymbolRefAttr PatternLowering::generateRewriter(
635  pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
636  builder.setInsertionPointToEnd(rewriterModule.getBody());
 // The function starts with an empty signature; the real one is set once the
 // arguments are known.
637  auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
638  pattern.getLoc(), "pdl_generated_rewriter",
639  builder.getFunctionType(std::nullopt, std::nullopt));
640  rewriterSymbolTable.insert(rewriterFunc);
641 
642  // Generate the rewriter function body.
643  builder.setInsertionPointToEnd(&rewriterFunc.front());
644 
645  // Map an input operand of the pattern to a generated interpreter value.
646  DenseMap<Value, Value> rewriteValues;
647  auto mapRewriteValue = [&](Value oldValue) {
 // `newValue` refers to the map slot, so the assignments below also cache
 // the result.
648  Value &newValue = rewriteValues[oldValue];
649  if (newValue)
650  return newValue;
651 
652  // Prefer materializing constants directly when possible.
 // NOTE(review): `oldOp` is passed to `dyn_cast` unchecked; presumably every
 // value reaching here is an operation result — confirm.
653  Operation *oldOp = oldValue.getDefiningOp();
654  if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
655  if (Attribute value = attrOp.getValueAttr()) {
656  return newValue = builder.create<pdl_interp::CreateAttributeOp>(
657  attrOp.getLoc(), value);
658  }
659  } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
660  if (TypeAttr type = typeOp.getConstantTypeAttr()) {
661  return newValue = builder.create<pdl_interp::CreateTypeOp>(
662  typeOp.getLoc(), type);
663  }
664  } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
665  if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
666  return newValue = builder.create<pdl_interp::CreateTypesOp>(
667  typeOp.getLoc(), typeOp.getType(), type);
668  }
669  }
670 
671  // Otherwise, add this as an input to the rewriter.
672  Position *inputPos = valueToPosition.lookup(oldValue);
673  assert(inputPos && "expected value to be a pattern input");
674  usedMatchValues.push_back(inputPos);
675  return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
676  oldValue.getLoc());
677  };
678 
679  // If this is a custom rewriter, simply dispatch to the registered rewrite
680  // method.
681  pdl::RewriteOp rewriter = pattern.getRewriter();
682  if (StringAttr rewriteName = rewriter.getNameAttr()) {
 // The root, when present, is passed first, followed by the external
 // arguments.
683  SmallVector<Value> args;
684  if (rewriter.getRoot())
685  args.push_back(mapRewriteValue(rewriter.getRoot()));
686  auto mappedArgs =
687  llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
688  args.append(mappedArgs.begin(), mappedArgs.end());
689  builder.create<pdl_interp::ApplyRewriteOp>(
690  rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
691  } else {
692  // Otherwise this is a dag rewriter defined using PDL operations.
693  for (Operation &rewriteOp : *rewriter.getBody()) {
 // NOTE(review): the line that opens this dispatch expression is not
 // visible in this listing; only its `.Case` clause is.
695  .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
696  pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
697  pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
698  this->generateRewriter(op, rewriteValues, mapRewriteValue);
699  });
700  }
701  }
702 
703  // Update the signature of the rewrite function.
 // The argument types are those of the block arguments added above.
704  rewriterFunc.setType(builder.getFunctionType(
705  llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
706  /*results=*/std::nullopt));
707 
708  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
709  return SymbolRefAttr::get(
710  builder.getContext(),
711  pdl_interp::PDLInterpDialect::getRewriterModuleName(),
712  SymbolRefAttr::get(rewriterFunc));
713 }
714 
715 void PatternLowering::generateRewriter(
716  pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
717  function_ref<Value(Value)> mapRewriteValue) {
718  SmallVector<Value, 2> arguments;
719  for (Value argument : rewriteOp.getArgs())
720  arguments.push_back(mapRewriteValue(argument));
721  auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
722  rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
723  arguments);
724  for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
725  rewriteValues[std::get<0>(it)] = std::get<1>(it);
726 }
727 
728 void PatternLowering::generateRewriter(
729  pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
730  function_ref<Value(Value)> mapRewriteValue) {
731  Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
732  attrOp.getLoc(), attrOp.getValueAttr());
733  rewriteValues[attrOp] = newAttr;
734 }
735 
736 void PatternLowering::generateRewriter(
737  pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
738  function_ref<Value(Value)> mapRewriteValue) {
739  builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
740  mapRewriteValue(eraseOp.getOpValue()));
741 }
742 
743 void PatternLowering::generateRewriter(
744  pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
745  function_ref<Value(Value)> mapRewriteValue) {
746  SmallVector<Value, 4> operands;
747  for (Value operand : operationOp.getOperandValues())
748  operands.push_back(mapRewriteValue(operand));
749 
750  SmallVector<Value, 4> attributes;
751  for (Value attr : operationOp.getAttributeValues())
752  attributes.push_back(mapRewriteValue(attr));
753 
754  bool hasInferredResultTypes = false;
755  SmallVector<Value, 2> types;
756  generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
757  rewriteValues, hasInferredResultTypes);
758 
759  // Create the new operation.
760  Location loc = operationOp.getLoc();
761  Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
762  loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
763  attributes, operationOp.getAttributeValueNames());
764  rewriteValues[operationOp.getOp()] = createdOp;
765 
766  // Generate accesses for any results that have their types constrained.
767  // Handle the case where there is a single range representing all of the
768  // result types.
769  OperandRange resultTys = operationOp.getTypeValues();
770  if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
771  Value &type = rewriteValues[resultTys[0]];
772  if (!type) {
773  auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
774  type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
775  }
776  return;
777  }
778 
779  // Otherwise, populate the individual results.
780  bool seenVariableLength = false;
781  Type valueTy = builder.getType<pdl::ValueType>();
782  Type valueRangeTy = pdl::RangeType::get(valueTy);
783  for (const auto &it : llvm::enumerate(resultTys)) {
784  Value &type = rewriteValues[it.value()];
785  if (type)
786  continue;
787  bool isVariadic = isa<pdl::RangeType>(it.value().getType());
788  seenVariableLength |= isVariadic;
789 
790  // After a variable length result has been seen, we need to use result
791  // groups because the exact index of the result is not statically known.
792  Value resultVal;
793  if (seenVariableLength)
794  resultVal = builder.create<pdl_interp::GetResultsOp>(
795  loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
796  else
797  resultVal = builder.create<pdl_interp::GetResultOp>(
798  loc, valueTy, createdOp, it.index());
799  type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
800  }
801 }
802 
803 void PatternLowering::generateRewriter(
804  pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
805  function_ref<Value(Value)> mapRewriteValue) {
806  SmallVector<Value, 4> replOperands;
807  for (Value operand : rangeOp.getArguments())
808  replOperands.push_back(mapRewriteValue(operand));
809  rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
810  rangeOp.getLoc(), rangeOp.getType(), replOperands);
811 }
812 
813 void PatternLowering::generateRewriter(
814  pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
815  function_ref<Value(Value)> mapRewriteValue) {
816  SmallVector<Value, 4> replOperands;
817 
818  // If the replacement was another operation, get its results. `pdl` allows
819  // for using an operation for simplicitly, but the interpreter isn't as
820  // user facing.
821  if (Value replOp = replaceOp.getReplOperation()) {
822  // Don't use replace if we know the replaced operation has no results.
823  auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
824  if (!opOp || !opOp.getTypeValues().empty()) {
825  replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
826  replOp.getLoc(), mapRewriteValue(replOp)));
827  }
828  } else {
829  for (Value operand : replaceOp.getReplValues())
830  replOperands.push_back(mapRewriteValue(operand));
831  }
832 
833  // If there are no replacement values, just create an erase instead.
834  if (replOperands.empty()) {
835  builder.create<pdl_interp::EraseOp>(
836  replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
837  return;
838  }
839 
840  builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
841  mapRewriteValue(replaceOp.getOpValue()),
842  replOperands);
843 }
844 
845 void PatternLowering::generateRewriter(
846  pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
847  function_ref<Value(Value)> mapRewriteValue) {
848  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
849  resultOp.getLoc(), builder.getType<pdl::ValueType>(),
850  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
851 }
852 
853 void PatternLowering::generateRewriter(
854  pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
855  function_ref<Value(Value)> mapRewriteValue) {
856  rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
857  resultOp.getLoc(), resultOp.getType(),
858  mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
859 }
860 
861 void PatternLowering::generateRewriter(
862  pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
863  function_ref<Value(Value)> mapRewriteValue) {
864  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
865  // type.
866  if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
867  rewriteValues[typeOp] =
868  builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
869  }
870 }
871 
872 void PatternLowering::generateRewriter(
873  pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
874  function_ref<Value(Value)> mapRewriteValue) {
875  // If the type isn't constant, the users (e.g. OperationOp) will resolve this
876  // type.
877  if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
878  rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
879  typeOp.getLoc(), typeOp.getType(), typeAttr);
880  }
881 }
882 
883 void PatternLowering::generateOperationResultTypeRewriter(
884  pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
885  SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
886  bool &hasInferredResultTypes) {
887  Block *rewriterBlock = op->getBlock();
888 
889  // Try to handle resolution for each of the result types individually. This is
890  // preferred over type inferrence because it will allow for us to use existing
891  // types directly, as opposed to trying to rebuild the type list.
892  OperandRange resultTypeValues = op.getTypeValues();
893  auto tryResolveResultTypes = [&] {
894  types.reserve(resultTypeValues.size());
895  for (const auto &it : llvm::enumerate(resultTypeValues)) {
896  Value resultType = it.value();
897 
898  // Check for an already translated value.
899  if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
900  types.push_back(existingRewriteValue);
901  continue;
902  }
903 
904  // Check for an input from the matcher.
905  if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
906  types.push_back(mapRewriteValue(resultType));
907  continue;
908  }
909 
910  // Otherwise, we couldn't infer the result types. Bail out here to see if
911  // we can infer the types for this operation from another way.
912  types.clear();
913  return failure();
914  }
915  return success();
916  };
917  if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
918  return;
919 
920  // Otherwise, check if the operation has type inference support itself.
921  if (op.hasTypeInference()) {
922  hasInferredResultTypes = true;
923  return;
924  }
925 
926  // Look for an operation that was replaced by `op`. The result types will be
927  // inferred from the results that were replaced.
928  for (OpOperand &use : op.getOp().getUses()) {
929  // Check that the use corresponds to a ReplaceOp and that it is the
930  // replacement value, not the operation being replaced.
931  pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
932  if (!replOpUser || use.getOperandNumber() == 0)
933  continue;
934  // Make sure the replaced operation was defined before this one. PDL
935  // rewrites only have single block regions, so if the op isn't in the
936  // rewriter block (i.e. the current block of the operation) we already know
937  // it dominates (i.e. it's in the matcher).
938  Value replOpVal = replOpUser.getOpValue();
939  Operation *replacedOp = replOpVal.getDefiningOp();
940  if (replacedOp->getBlock() == rewriterBlock &&
941  !replacedOp->isBeforeInBlock(op))
942  continue;
943 
944  Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
945  replacedOp->getLoc(), mapRewriteValue(replOpVal));
946  types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
947  replacedOp->getLoc(), replacedOpResults));
948  return;
949  }
950 
951  // If the types could not be inferred from any context and there weren't any
952  // explicit result types, assume the user actually meant for the operation to
953  // have no results.
954  if (resultTypeValues.empty())
955  return;
956 
957  // The verifier asserts that the result types of each pdl.getOperation can be
958  // inferred. If we reach here, there is a bug either in the logic above or
959  // in the verifier for pdl.getOperation.
960  op->emitOpError() << "unable to infer result type for operation";
961  llvm_unreachable("unable to infer result type for operation");
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // Conversion Pass
966 //===----------------------------------------------------------------------===//
967 
968 namespace {
969 struct PDLToPDLInterpPass
970  : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
971  PDLToPDLInterpPass() = default;
972  PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
973  PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
974  : configMap(&configMap) {}
975  void runOnOperation() final;
976 
977  /// A map containing the configuration for each pattern.
978  DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
979 };
980 } // namespace
981 
982 /// Convert the given module containing PDL pattern operations into a PDL
983 /// Interpreter operations.
984 void PDLToPDLInterpPass::runOnOperation() {
985  ModuleOp module = getOperation();
986 
987  // Create the main matcher function This function contains all of the match
988  // related functionality from patterns in the module.
989  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
990  auto matcherFunc = builder.create<pdl_interp::FuncOp>(
991  module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
992  builder.getFunctionType(builder.getType<pdl::OperationType>(),
993  /*results=*/std::nullopt),
994  /*attrs=*/std::nullopt);
995 
996  // Create a nested module to hold the functions invoked for rewriting the IR
997  // after a successful match.
998  ModuleOp rewriterModule = builder.create<ModuleOp>(
999  module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
1000 
1001  // Generate the code for the patterns within the module.
1002  PatternLowering generator(matcherFunc, rewriterModule, configMap);
1003  generator.lower(module);
1004 
1005  // After generation, delete all of the pattern operations.
1006  for (pdl::PatternOp pattern :
1007  llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1008  // Drop the now dead config mappings.
1009  if (configMap)
1010  configMap->erase(pattern);
1011 
1012  pattern.erase();
1013  }
1014 }
1015 
1016 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
1017  return std::make_unique<PDLToPDLInterpPass>();
1018 }
1019 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
1021  return std::make_unique<PDLToPDLInterpPass>(configMap);
1022 }
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:68
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
void push_back(Operation *op)
Definition: Block.h:149
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:99
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Block & emplaceBlock()
Definition: Region.h:46
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class represents the base of a predicate matcher node.
Definition: PredicateTree.h:50
Position * getPosition() const
Returns the position on which the question predicate should be checked.
Definition: PredicateTree.h:63
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Definition: PredicateTree.h:70
Qualifier * getQuestion() const
Returns the predicate checked on this node.
Definition: PredicateTree.h:66
A position describes a value on the input IR on which a predicate may be applied, such as an operatio...
Definition: Predicate.h:144
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition: Predicate.h:153
Predicates::Kind getKind() const
Returns the kind of this position.
Definition: Predicate.h:156
This class provides utilities for constructing predicates.
Definition: Predicate.h:596
This class provides a storage uniquer that is used to allocate predicate instances.
Definition: Predicate.h:552
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to o...
Definition: Predicate.h:410
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition: Predicate.h:415
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition: Predicate.h:46
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createPDLToPDLInterpPass()
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
A BoolNode denotes a question with a boolean-like result.
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
An Answer representing an OperationName value.
Definition: Predicate.h:435
A SuccessNode denotes that a given high level pattern has successfully been matched.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
A SwitchNode denotes a question with multiple potential results.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.