MLIR 23.0.0git
PDLToPDLInterp.cpp
Go to the documentation of this file.
1//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "PredicateTree.h"
14#include "mlir/Pass/Pass.h"
15#include "llvm/ADT/MapVector.h"
16#include "llvm/ADT/ScopedHashTable.h"
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/TypeSwitch.h"
20
// Instantiate, inside namespace `mlir`, the generated pass definition that is
// selected by defining GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS before including
// the generated Passes.h.inc.
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::pdl_to_pdl_interp;
28
29//===----------------------------------------------------------------------===//
30// PatternLowering
31//===----------------------------------------------------------------------===//
32
33namespace {
34/// This class generates operations within the PDL Interpreter dialect from a
35/// given module containing PDL pattern operations.
/// Match logic is emitted into `matcherFunc`, while each rewrite is emitted as
/// its own pdl_interp.func inside `rewriterModule`.
36struct PatternLowering {
37public:
38 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
 // NOTE(review): the final constructor parameter line is missing from this
 // listing; the out-of-line constructor initializes `configMap` from a
 // parameter of that name.
40
41 /// Generate code for matching and rewriting based on the pattern operations
42 /// within the module.
43 void lower(ModuleOp module);
44
45private:
 // Scoped map from a matcher position to the interpreter value computed for
 // it; scopes are pushed while descending the matcher tree.
46 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
47 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
48
49 /// Generate interpreter operations for the tree rooted at the given matcher
50 /// node, in the specified region. If `block` is null, a new block is
 /// appended to `region`; the block where generation starts is returned.
51 Block *generateMatcher(MatcherNode &node, Region &region,
52 Block *block = nullptr);
53
54 /// Get or create an access to the provided positional value in the current
55 /// block. This operation may mutate the provided block pointer if nested
56 /// regions (i.e., pdl_interp.iterate) are required.
57 Value getValueAt(Block *&currentBlock, Position *pos);
58
59 /// Create the interpreter predicate operations. This operation may mutate the
60 /// provided current block pointer if nested regions (iterates) are required.
61 void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
62
63 /// Create the interpreter switch / predicate operations, with several case
64 /// destinations. This operation never mutates the provided current block
65 /// pointer, because the switch operation does not need Values beyond `val`.
66 void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
67
68 /// Create the interpreter operations to record a successful pattern match
69 /// using the contained root operation. This operation may mutate the current
70 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
71 void generate(SuccessNode *successNode, Block *&currentBlock);
72
73 /// Generate a rewriter function for the given pattern operation and return
74 /// a reference to that function. Matcher positions the rewriter consumes
 /// are appended to `usedMatchValues`.
75 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76 SmallVectorImpl<Position *> &usedMatchValues);
77
78 /// Generate the rewriter code for the given operation. `rewriteValues` maps
 /// pattern values to already-lowered rewriter values, and `mapRewriteValue`
 /// looks up (or materializes) the rewriter value for a pattern value.
79 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80 DenseMap<Value, Value> &rewriteValues,
81 function_ref<Value(Value)> mapRewriteValue);
82 void generateRewriter(pdl::AttributeOp attrOp,
83 DenseMap<Value, Value> &rewriteValues,
84 function_ref<Value(Value)> mapRewriteValue);
85 void generateRewriter(pdl::EraseOp eraseOp,
86 DenseMap<Value, Value> &rewriteValues,
87 function_ref<Value(Value)> mapRewriteValue);
88 void generateRewriter(pdl::OperationOp operationOp,
89 DenseMap<Value, Value> &rewriteValues,
90 function_ref<Value(Value)> mapRewriteValue);
91 void generateRewriter(pdl::RangeOp rangeOp,
92 DenseMap<Value, Value> &rewriteValues,
93 function_ref<Value(Value)> mapRewriteValue);
94 void generateRewriter(pdl::ReplaceOp replaceOp,
95 DenseMap<Value, Value> &rewriteValues,
96 function_ref<Value(Value)> mapRewriteValue);
97 void generateRewriter(pdl::ResultOp resultOp,
98 DenseMap<Value, Value> &rewriteValues,
99 function_ref<Value(Value)> mapRewriteValue);
100 void generateRewriter(pdl::ResultsOp resultOp,
101 DenseMap<Value, Value> &rewriteValues,
102 function_ref<Value(Value)> mapRewriteValue);
103 void generateRewriter(pdl::TypeOp typeOp,
104 DenseMap<Value, Value> &rewriteValues,
105 function_ref<Value(Value)> mapRewriteValue);
106 void generateRewriter(pdl::TypesOp typeOp,
107 DenseMap<Value, Value> &rewriteValues,
108 function_ref<Value(Value)> mapRewriteValue);
109
110 /// Generate the values used for resolving the result types of an operation
111 /// created within a dag rewriter region. If the result types of the operation
112 /// should be inferred, `hasInferredResultTypes` is set to true.
113 void generateOperationResultTypeRewriter(
114 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
115 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
116 bool &hasInferredResultTypes);
117
118 /// A builder to use when generating interpreter operations.
119 OpBuilder builder;
120
121 /// The matcher function used for all match related logic within PDL patterns.
122 pdl_interp::FuncOp matcherFunc;
123
124 /// The rewriter module containing all rewrite-related logic within PDL
125 /// patterns.
126 ModuleOp rewriterModule;
127
128 /// The symbol table of the rewriter module used for insertion.
129 SymbolTable rewriterSymbolTable;
130
131 /// A scoped map connecting a position with the corresponding interpreter
132 /// value.
133 ValueMap values;
134
135 /// A stack of blocks used as the failure destination for matcher nodes that
136 /// don't have an explicit failure path.
137 SmallVector<Block *, 8> failureBlockStack;
138
139 /// A mapping between values defined in a pattern match, and the corresponding
140 /// positional value.
141 DenseMap<Value, Position *> valueToPosition;
142
143 /// The set of operation values whose location will be used for newly
144 /// generated operations.
145 SetVector<Value> locOps;
146
147 /// A mapping between pattern operations and the corresponding configuration
148 /// set.
 // NOTE(review): the declaration of this member (`configMap`) is missing from
 // this listing. Its uses (`if (configMap)`, `configMap->lookup(pattern)`,
 // `configMap->try_emplace(matchOp, ...)`) show it is a nullable pointer to a
 // map keyed by operations.
150
151 /// A mapping from a constraint question to the ApplyConstraintOp
152 /// that implements it.
 // NOTE(review): the declaration of this member (`constraintOpMap`) is missing
 // from this listing; it is filled in generate(BoolNode *) and read when
 // lowering constraint-result positions in getValueAt.
154};
155} // namespace
156
/// Binds the lowering to the matcher function and the rewriter module. The
/// builder uses the matcher function's context, and the rewriter symbol table
/// is constructed over `rewriterModule`.
157PatternLowering::PatternLowering(
158 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
 // NOTE(review): the remaining parameter line (the `configMap` parameter,
 // judging by the initializer list) is missing from this listing.
160 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
161 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162 configMap(configMap) {}
163
164void PatternLowering::lower(ModuleOp module) {
165 PredicateUniquer predicateUniquer;
166 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
167
168 // Define top-level scope for the arguments to the matcher function.
169 ValueMapScope topLevelValueScope(values);
170
171 // Insert the root operation, i.e. argument to the matcher, at the root
172 // position.
173 Block *matcherEntryBlock = &matcherFunc.front();
174 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
175
176 // Generate a root matcher node from the provided PDL module.
177 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
178 module, predicateBuilder, valueToPosition);
179 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
180 assert(failureBlockStack.empty() && "failed to empty the stack");
181
182 // After generation, merged the first matched block into the entry.
183 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
184 firstMatcherBlock->getOperations());
185 firstMatcherBlock->erase();
186}
187
/// Emits the interpreter code for the matcher subtree rooted at `node` into
/// `region`, starting in `block` (a new block is appended to `region` when
/// `block` is null), and returns that starting block.
///
/// If the node has a failure subtree, that subtree is generated first and its
/// block is pushed onto failureBlockStack. Otherwise the current top of the
/// stack is the failure destination. Any failure blocks pushed while lowering
/// this node (by nested pdl_interp.foreach regions) are popped before
/// returning.
// NOTE(review): the `llvm::TypeSwitch<...>(&node)` line that starts the
// dispatch chain below is missing from this listing.
188Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
189 Block *block) {
190 // Create the starting block if none was supplied, then push a new scope for
 // the values used by this matcher.
191 if (!block)
192 block = &region.emplaceBlock();
193 ValueMapScope scope(values);
194
195 // If this is the exit node, simply insert the corresponding interpreter
196 // finalize.
197 if (isa<ExitNode>(node)) {
198 builder.setInsertionPointToEnd(block);
199 pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
200 return block;
201 }
202
203 // Get the next block in the match sequence.
204 // This is intentionally executed first, before we get the value for the
205 // position associated with the node, so that we preserve a "there exists"
206 // semantics: if getting a value requires an upward traversal (going from a
207 // value to its consumers), we want to perform the check on all the consumers
208 // before we pass control to the failure node.
209 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
210 Block *failureBlock;
211 if (failureNode) {
212 failureBlock = generateMatcher(*failureNode, region);
213 failureBlockStack.push_back(failureBlock);
214 } else {
215 assert(!failureBlockStack.empty() && "expected valid failure block");
216 failureBlock = failureBlockStack.back();
217 }
218
219 // If this node contains a position, get the corresponding value for this
220 // block. This may move `currentBlock` into a nested region.
221 Block *currentBlock = block;
222 Position *position = node.getPosition();
223 Value val = position ? getValueAt(currentBlock, position) : Value();
224
225 // If this value corresponds to an operation, record that we are going to use
226 // its location as part of a fused location.
227 bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
228 if (isOperationValue)
229 locOps.insert(val);
230
231 // Dispatch to the correct method based on derived node type.
233 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
234 this->generate(derivedNode, currentBlock, val);
235 })
236 .Case([&](SuccessNode *successNode) {
237 generate(successNode, currentBlock);
238 });
239
240 // Pop all the failure blocks that were inserted due to nesting of
241 // pdl_interp.iterate.
242 while (failureBlockStack.back() != failureBlock) {
243 failureBlockStack.pop_back();
244 assert(!failureBlockStack.empty() && "unable to locate failure block");
245 }
246
247 // Pop the new failure block.
248 if (failureNode)
249 failureBlockStack.pop_back();
250
 // The operation's location is only relevant within this subtree.
251 if (isOperationValue)
252 locOps.remove(val);
253
254 return block;
255}
256
/// Returns the interpreter value for position `pos`, materializing it if it is
/// not already present in the scoped `values` map.
///
/// The parent position's value is obtained first (recursively). Then the
/// access operation for `pos` is appended to `currentBlock`, and the result is
/// cached in `values`. The branch that creates a pdl_interp.foreach over the
/// parent value does two more things: it pushes a block containing
/// pdl_interp.continue onto failureBlockStack, and it moves `currentBlock` to
/// the first block of the foreach region, so later code is emitted inside the
/// loop.
// NOTE(review): most `case` labels of the switch below are missing from this
// listing; each branch is identified here by the Position subclass it casts
// `pos` to (or, for the foreach branch, by the op it creates).
257Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
258 if (Value val = values.lookup(pos))
259 return val;
260
261 // Get the value for the parent position.
262 Value parentVal;
263 if (Position *parent = pos->getParent())
264 parentVal = getValueAt(currentBlock, parent);
265
266 // TODO: Use a location from the position.
267 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
268 builder.setInsertionPointToEnd(currentBlock);
269 Value value;
270 switch (pos->getKind()) {
 // OperationPosition: the defining op of the parent value, or the parent
 // value itself for a passthrough position.
272 auto *operationPos = cast<OperationPosition>(pos);
273 if (operationPos->isOperandDefiningOp())
274 // Standard (downward) traversal which directly follows the defining op.
275 value = pdl_interp::GetDefiningOpOp::create(
276 builder, loc, builder.getType<pdl::OperationType>(), parentVal);
277 else
278 // A passthrough operation position.
279 value = parentVal;
280 break;
281 }
 // UsersPosition: the users of the parent value (upward traversal).
283 auto *usersPos = cast<UsersPosition>(pos);
284
285 // The first operation retrieves the representative value of a range.
286 // This applies only when the parent is a range of values and we were
287 // requested to use a representative value (e.g., upward traversal).
288 if (isa<pdl::RangeType>(parentVal.getType()) &&
289 usersPos->useRepresentative())
290 value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
291 else
292 value = parentVal;
293
294 // The second operation retrieves the users.
295 value = pdl_interp::GetUsersOp::create(builder, loc, value);
296 break;
297 }
 // Presumably the ForEach position: iterate over the parent range, binding
 // the loop variable as this position's value.
299 assert(!failureBlockStack.empty() && "expected valid failure block");
300 auto foreach = pdl_interp::ForEachOp::create(
301 builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
302 value = foreach.getLoopVariable();
303
304 // Create the continuation block. Failures inside the loop body branch here
 // (it becomes the top of failureBlockStack) to advance to the next element.
305 Block *continueBlock = builder.createBlock(&foreach.getRegion());
306 pdl_interp::ContinueOp::create(builder, loc);
307 failureBlockStack.push_back(continueBlock);
308
 // Subsequent code is emitted into the loop body.
309 currentBlock = &foreach.getRegion().front();
310 break;
311 }
 // OperandPosition: a single operand of the parent operation.
313 auto *operandPos = cast<OperandPosition>(pos);
314 value = pdl_interp::GetOperandOp::create(
315 builder, loc, builder.getType<pdl::ValueType>(), parentVal,
316 operandPos->getOperandNumber());
317 break;
318 }
 // OperandGroupPosition: an operand group, as a range when variadic.
320 auto *operandPos = cast<OperandGroupPosition>(pos);
321 Type valueTy = builder.getType<pdl::ValueType>();
322 value = pdl_interp::GetOperandsOp::create(
323 builder, loc,
324 operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
325 parentVal, operandPos->getOperandGroupNumber());
326 break;
327 }
 // AttributePosition: a named attribute of the parent operation.
329 auto *attrPos = cast<AttributePosition>(pos);
330 value = pdl_interp::GetAttributeOp::create(
331 builder, loc, builder.getType<pdl::AttributeType>(), parentVal,
332 attrPos->getName().strref());
333 break;
334 }
335 case Predicates::TypePos: {
 // The type of an attribute or of a value (range), depending on the parent.
336 if (isa<pdl::AttributeType>(parentVal.getType()))
337 value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
338 else
339 value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
340 break;
341 }
 // ResultPosition: a single result of the parent operation.
343 auto *resPos = cast<ResultPosition>(pos);
344 value = pdl_interp::GetResultOp::create(
345 builder, loc, builder.getType<pdl::ValueType>(), parentVal,
346 resPos->getResultNumber());
347 break;
348 }
 // ResultGroupPosition: a result group, as a range when variadic.
350 auto *resPos = cast<ResultGroupPosition>(pos);
351 Type valueTy = builder.getType<pdl::ValueType>();
352 value = pdl_interp::GetResultsOp::create(
353 builder, loc,
354 resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
356 break;
357 }
 // AttributeLiteralPosition: a constant attribute.
359 auto *attrPos = cast<AttributeLiteralPosition>(pos);
360 value = pdl_interp::CreateAttributeOp::create(builder, loc,
361 attrPos->getValue());
362 break;
363 }
 // TypeLiteralPosition: a constant type (TypeAttr) or type list (ArrayAttr).
365 auto *typePos = cast<TypeLiteralPosition>(pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
369 else
370 value = pdl_interp::CreateTypesOp::create(builder, loc,
371 cast<ArrayAttr>(rawTypeAttr));
372 break;
373 }
 // ConstraintPosition: a result of a previously emitted apply_constraint.
375 // Due to the order of traversal, the ApplyConstraintOp has already been
376 // created and we can find it in constraintOpMap.
377 auto *constrResPos = cast<ConstraintPosition>(pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
381 break;
382 }
383 default:
384 llvm_unreachable("Generating unknown Position getter");
385 break;
386 }
387
 // Cache the value in the innermost scope so later lookups reuse it.
388 values.insert(pos, value);
389 return value;
390}
391
/// Emits the check for a boolean predicate node on `val` at the end of
/// `currentBlock`. The check branches either to a freshly created success
/// block or to the top of failureBlockStack. The success subtree is then
/// generated into the success block.
///
/// Extra predicate operands (the compared value for EqualTo questions, the
/// arguments for constraint questions) are materialized first; doing so may
/// move `currentBlock` into a nested region.
// NOTE(review): the `case` labels of the switch below are missing from this
// listing; each branch is identified by the op it creates.
392void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393 Value val) {
394 Location loc = val.getLoc();
395 Qualifier *question = boolNode->getQuestion();
396 Qualifier *answer = boolNode->getAnswer();
397 Region *region = currentBlock->getParent();
 // NOTE(review): `region` is captured before the getValueAt calls below. Those
 // calls may move `currentBlock` into a nested pdl_interp.foreach region, and
 // then the success block would be created in the outer region while the
 // predicate is appended to the nested block. Confirm this cannot happen (or
 // re-read the parent region after materializing `args`).
398
399 // Execute the getValue queries first, so that we create success
400 // matcher in the correct (possibly nested) region.
401 SmallVector<Value> args;
402 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(getValueAt(currentBlock, position));
407 }
408
409 // Generate a new block as success successor and get the failure successor.
410 Block *success = &region->emplaceBlock();
411 Block *failure = failureBlockStack.back();
412
413 // Create the predicate.
414 builder.setInsertionPointToEnd(currentBlock);
415 Predicates::Kind kind = question->getKind();
416 switch (kind) {
 // Non-null check.
418 pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure);
419 break;
 // Operation name check.
421 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422 pdl_interp::CheckOperationNameOp::create(
423 builder, loc, val, opNameAnswer->getValue().getStringRef(), success,
424 failure);
425 break;
426 }
 // Type check: a type list for range values, a single type otherwise.
428 auto *ans = cast<TypeAnswer>(answer);
429 if (isa<pdl::RangeType>(val.getType()))
430 pdl_interp::CheckTypesOp::create(builder, loc, val,
431 llvm::cast<ArrayAttr>(ans->getValue()),
432 success, failure);
433 else
434 pdl_interp::CheckTypeOp::create(builder, loc, val,
435 llvm::cast<TypeAttr>(ans->getValue()),
436 success, failure);
437 break;
438 }
 // Attribute value check.
440 auto *ans = cast<AttributeAnswer>(answer);
441 pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
442 success, failure);
443 break;
444 }
 // Operand count check (exact, or at-least for the AtLeast question kind).
447 pdl_interp::CheckOperandCountOp::create(
448 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
449 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
450 success, failure);
451 break;
 // Result count check (exact, or at-least for the AtLeast question kind).
454 pdl_interp::CheckResultCountOp::create(
455 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
456 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
457 success, failure);
458 break;
 // Equality check; a non-TrueAnswer swaps the destinations, so the
 // "not equal" outcome leads to `success`.
460 bool trueAnswer = isa<TrueAnswer>(answer);
461 pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
462 trueAnswer ? success : failure,
463 trueAnswer ? failure : success);
464 break;
465 }
 // Native constraint; the op is recorded so that constraint-result
 // positions can later be resolved in getValueAt.
467 auto *cstQuestion = cast<ConstraintQuestion>(question);
468 auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
469 builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
470 args, cstQuestion->getIsNegated(), success, failure);
471
472 constraintOpMap.insert({cstQuestion, applyConstraintOp});
473 break;
474 }
475 default:
476 llvm_unreachable("Generating unknown Predicate operation");
477 }
478
479 // Generate the matcher in the current (potentially nested) region.
480 // This might use the results of the current predicate.
481 generateMatcher(*boolNode->getSuccessNode(), *region, success);
482}
483
484template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
485static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
486 llvm::MapVector<Qualifier *, Block *> &dests) {
487 std::vector<ValT> values;
488 std::vector<Block *> blocks;
489 values.reserve(dests.size());
490 blocks.reserve(dests.size());
491 for (const auto &it : dests) {
492 blocks.push_back(it.second);
493 values.push_back(cast<PredT>(it.first)->getValue());
494 }
495 OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks);
496}
497
/// Emits a multi-way dispatch on `val` for a switch node. `currentBlock`
/// itself is never replaced.
///
/// The first branch below handles the "at least" count questions (judging by
/// its comment and its CheckOperandCountOp/CheckResultCountOp with
/// compareAtLeast=true). It builds a chain of count checks, one per child, and
/// splices the first check into `currentBlock`. All other questions create a
/// single pdl_interp switch op via createSwitchOp, with the top of
/// failureBlockStack as the default destination.
// NOTE(review): the `if` that opens the at-least branch (closed by the
// `return;` / `}` below) and the `case` labels / first lines of several
// `createSwitchOp` calls are missing from this listing.
498void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
499 Value val) {
500 Qualifier *question = switchNode->getQuestion();
501 Region *region = currentBlock->getParent();
502 Block *defaultDest = failureBlockStack.back();
503
504 // If the switch question is not an exact answer, i.e. for the `at_least`
505 // cases, we generate a special block sequence.
506 Predicates::Kind kind = question->getKind();
509 // Order the children such that the cases are in reverse numerical order.
510 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
511 llvm::seq<unsigned>(0, switchNode->getChildren().size()));
512 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
513 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
514 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
515 });
516
517 // Build the destination for each child using the next highest child as a
518 // failure destination. This essentially creates the following control
519 // flow:
520 //
521 // if (operand_count < 1)
522 // goto failure
523 // if (child1.match())
524 // ...
525 //
526 // if (operand_count < 2)
527 // goto failure
528 // if (child2.match())
529 // ...
530 //
531 // failure:
532 // ...
533 //
 // The top of failureBlockStack tracks the most recently built predicate
 // block, which serves as the failure destination of the next (lower-count)
 // child.
534 failureBlockStack.push_back(defaultDest);
535 Location loc = val.getLoc();
536 for (unsigned idx : sortedChildren) {
537 auto &child = switchNode->getChild(idx);
538 Block *childBlock = generateMatcher(*child.second, *region);
539 Block *predicateBlock = builder.createBlock(childBlock);
540 builder.setInsertionPointToEnd(predicateBlock);
541 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
542 switch (kind) {
544 pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
545 /*compareAtLeast=*/true,
546 childBlock, defaultDest);
547 break;
549 pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
550 /*compareAtLeast=*/true,
551 childBlock, defaultDest);
552 break;
553 default:
554 llvm_unreachable("Generating invalid AtLeast operation");
555 }
556 failureBlockStack.back() = predicateBlock;
557 }
 // The last predicate block built (lowest count) is the entry of the chain;
 // move its contents into the current block.
558 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
559 currentBlock->getOperations().splice(currentBlock->end(),
560 firstPredicateBlock->getOperations());
561 firstPredicateBlock->erase();
562 return;
563 }
564
565 // Otherwise, generate each of the children and generate an interpreter
566 // switch.
567 llvm::MapVector<Qualifier *, Block *> children;
568 for (auto &it : switchNode->getChildren())
569 children.insert({it.first, generateMatcher(*it.second, *region)});
570 builder.setInsertionPointToEnd(currentBlock);
571
572 switch (question->getKind()) {
574 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
575 int32_t>(val, defaultDest, builder, children);
577 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
578 int32_t>(val, defaultDest, builder, children);
580 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
581 OperationNameAnswer>(val, defaultDest, builder,
582 children);
 // Type switch: a range-typed value uses the multi-type switch form.
584 if (isa<pdl::RangeType>(val.getType())) {
586 val, defaultDest, builder, children);
587 }
589 val, defaultDest, builder, children);
592 val, defaultDest, builder, children);
593 default:
594 llvm_unreachable("Generating unknown switch predicate.");
595 }
596}
597
598void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
599 pdl::PatternOp pattern = successNode->getPattern();
600 Value root = successNode->getRoot();
601
602 // Generate a rewriter for the pattern this success node represents, and track
603 // any values used from the match region.
604 SmallVector<Position *, 8> usedMatchValues;
605 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
606
607 // Process any values used in the rewrite that are defined in the match.
608 std::vector<Value> mappedMatchValues;
609 mappedMatchValues.reserve(usedMatchValues.size());
610 for (Position *position : usedMatchValues)
611 mappedMatchValues.push_back(getValueAt(currentBlock, position));
612
613 // Collect the set of operations generated by the rewriter.
614 SmallVector<StringRef, 4> generatedOps;
615 for (auto op :
616 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
617 generatedOps.push_back(*op.getOpName());
618 ArrayAttr generatedOpsAttr;
619 if (!generatedOps.empty())
620 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
621
622 // Grab the root kind if present.
623 StringAttr rootKindAttr;
624 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
625 if (std::optional<StringRef> rootKind = rootOp.getOpName())
626 rootKindAttr = builder.getStringAttr(*rootKind);
627
628 builder.setInsertionPointToEnd(currentBlock);
629 auto matchOp = pdl_interp::RecordMatchOp::create(
630 builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
631 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
632 failureBlockStack.back());
633
634 // Set the config of the lowered match to the parent pattern.
635 if (configMap)
636 configMap->try_emplace(matchOp, configMap->lookup(pattern));
637}
638
639SymbolRefAttr PatternLowering::generateRewriter(
640 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
641 builder.setInsertionPointToEnd(rewriterModule.getBody());
642 // Get the pattern name if available, otherwise use default
643 StringRef rewriterName = "pdl_generated_rewriter";
644 if (auto symName = pattern.getSymName())
645 rewriterName = symName.value();
646 auto rewriterFunc = pdl_interp::FuncOp::create(
647 builder, pattern.getLoc(), rewriterName, builder.getFunctionType({}, {}));
648 rewriterSymbolTable.insert(rewriterFunc);
649
650 // Generate the rewriter function body.
651 builder.setInsertionPointToEnd(&rewriterFunc.front());
652
653 // Map an input operand of the pattern to a generated interpreter value.
654 DenseMap<Value, Value> rewriteValues;
655 auto mapRewriteValue = [&](Value oldValue) {
656 Value &newValue = rewriteValues[oldValue];
657 if (newValue)
658 return newValue;
659
660 // Prefer materializing constants directly when possible.
661 Operation *oldOp = oldValue.getDefiningOp();
662 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
663 if (Attribute value = attrOp.getValueAttr()) {
664 return newValue = pdl_interp::CreateAttributeOp::create(
665 builder, attrOp.getLoc(), value);
666 }
667 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
668 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
669 return newValue = pdl_interp::CreateTypeOp::create(
670 builder, typeOp.getLoc(), type);
671 }
672 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
673 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
674 return newValue = pdl_interp::CreateTypesOp::create(
675 builder, typeOp.getLoc(), typeOp.getType(), type);
676 }
677 }
678
679 // Otherwise, add this as an input to the rewriter.
680 Position *inputPos = valueToPosition.lookup(oldValue);
681 assert(inputPos && "expected value to be a pattern input");
682 usedMatchValues.push_back(inputPos);
683 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
684 oldValue.getLoc());
685 };
686
687 // If this is a custom rewriter, simply dispatch to the registered rewrite
688 // method.
689 pdl::RewriteOp rewriter = pattern.getRewriter();
690 if (StringAttr rewriteName = rewriter.getNameAttr()) {
691 SmallVector<Value> args;
692 if (rewriter.getRoot())
693 args.push_back(mapRewriteValue(rewriter.getRoot()));
694 auto mappedArgs =
695 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
696 args.append(mappedArgs.begin(), mappedArgs.end());
697 pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
698 /*results=*/TypeRange(), rewriteName,
699 args);
700 } else {
701 // Otherwise this is a dag rewriter defined using PDL operations.
702 for (Operation &rewriteOp : *rewriter.getBody()) {
703 llvm::TypeSwitch<Operation *>(&rewriteOp)
704 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
705 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
706 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
707 this->generateRewriter(op, rewriteValues, mapRewriteValue);
708 });
709 }
710 }
711
712 // Update the signature of the rewrite function.
713 rewriterFunc.setType(builder.getFunctionType(
714 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
715 /*results=*/{}));
716
717 pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
718 return SymbolRefAttr::get(
719 builder.getContext(),
720 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
721 SymbolRefAttr::get(rewriterFunc));
722}
723
724void PatternLowering::generateRewriter(
725 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
726 function_ref<Value(Value)> mapRewriteValue) {
727 SmallVector<Value, 2> arguments;
728 for (Value argument : rewriteOp.getArgs())
729 arguments.push_back(mapRewriteValue(argument));
730 auto interpOp = pdl_interp::ApplyRewriteOp::create(
731 builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
732 rewriteOp.getNameAttr(), arguments);
733 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
734 rewriteValues[std::get<0>(it)] = std::get<1>(it);
735}
736
737void PatternLowering::generateRewriter(
738 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
739 function_ref<Value(Value)> mapRewriteValue) {
740 Value newAttr = pdl_interp::CreateAttributeOp::create(
741 builder, attrOp.getLoc(), attrOp.getValueAttr());
742 rewriteValues[attrOp] = newAttr;
743}
744
745void PatternLowering::generateRewriter(
746 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
747 function_ref<Value(Value)> mapRewriteValue) {
748 pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
749 mapRewriteValue(eraseOp.getOpValue()));
750}
751
752void PatternLowering::generateRewriter(
753 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
754 function_ref<Value(Value)> mapRewriteValue) {
755 SmallVector<Value, 4> operands;
756 for (Value operand : operationOp.getOperandValues())
757 operands.push_back(mapRewriteValue(operand));
758
759 SmallVector<Value, 4> attributes;
760 for (Value attr : operationOp.getAttributeValues())
761 attributes.push_back(mapRewriteValue(attr));
762
763 bool hasInferredResultTypes = false;
764 SmallVector<Value, 2> types;
765 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
766 rewriteValues, hasInferredResultTypes);
767
768 // Create the new operation.
769 Location loc = operationOp.getLoc();
770 Value createdOp = pdl_interp::CreateOperationOp::create(
771 builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
772 operands, attributes, operationOp.getAttributeValueNames());
773 rewriteValues[operationOp.getOp()] = createdOp;
774
775 // Generate accesses for any results that have their types constrained.
776 // Handle the case where there is a single range representing all of the
777 // result types.
778 OperandRange resultTys = operationOp.getTypeValues();
779 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
780 Value &type = rewriteValues[resultTys[0]];
781 if (!type) {
782 auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
783 type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
784 }
785 return;
786 }
787
788 // Otherwise, populate the individual results.
789 bool seenVariableLength = false;
790 Type valueTy = builder.getType<pdl::ValueType>();
791 Type valueRangeTy = pdl::RangeType::get(valueTy);
792 for (const auto &it : llvm::enumerate(resultTys)) {
793 Value &type = rewriteValues[it.value()];
794 if (type)
795 continue;
796 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
797 seenVariableLength |= isVariadic;
798
799 // After a variable length result has been seen, we need to use result
800 // groups because the exact index of the result is not statically known.
801 Value resultVal;
802 if (seenVariableLength)
803 resultVal = pdl_interp::GetResultsOp::create(
804 builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
805 it.index());
806 else
807 resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
808 createdOp, it.index());
809 type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
810 }
811}
812
813void PatternLowering::generateRewriter(
814 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
815 function_ref<Value(Value)> mapRewriteValue) {
816 SmallVector<Value, 4> replOperands;
817 for (Value operand : rangeOp.getArguments())
818 replOperands.push_back(mapRewriteValue(operand));
819 rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
820 builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
821}
822
823void PatternLowering::generateRewriter(
824 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
825 function_ref<Value(Value)> mapRewriteValue) {
826 SmallVector<Value, 4> replOperands;
827
828 // If the replacement was another operation, get its results. `pdl` allows
829 // for using an operation for simplicitly, but the interpreter isn't as
830 // user facing.
831 if (Value replOp = replaceOp.getReplOperation()) {
832 // Don't use replace if we know the replaced operation has no results.
833 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
834 if (!opOp || !opOp.getTypeValues().empty()) {
835 replOperands.push_back(pdl_interp::GetResultsOp::create(
836 builder, replOp.getLoc(), mapRewriteValue(replOp)));
837 }
838 } else {
839 for (Value operand : replaceOp.getReplValues())
840 replOperands.push_back(mapRewriteValue(operand));
841 }
842
843 // If there are no replacement values, just create an erase instead.
844 if (replOperands.empty()) {
845 pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
846 mapRewriteValue(replaceOp.getOpValue()));
847 return;
848 }
849
850 pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
851 mapRewriteValue(replaceOp.getOpValue()),
852 replOperands);
853}
854
855void PatternLowering::generateRewriter(
856 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
857 function_ref<Value(Value)> mapRewriteValue) {
858 rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
859 builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
860 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
861}
862
863void PatternLowering::generateRewriter(
864 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
865 function_ref<Value(Value)> mapRewriteValue) {
866 rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
867 builder, resultOp.getLoc(), resultOp.getType(),
868 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
869}
870
871void PatternLowering::generateRewriter(
872 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
873 function_ref<Value(Value)> mapRewriteValue) {
874 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
875 // type.
876 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
877 rewriteValues[typeOp] =
878 pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
879 }
880}
881
882void PatternLowering::generateRewriter(
883 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
884 function_ref<Value(Value)> mapRewriteValue) {
885 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
886 // type.
887 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
888 rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
889 builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
890 }
891}
892
893void PatternLowering::generateOperationResultTypeRewriter(
894 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
895 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
896 bool &hasInferredResultTypes) {
897 Block *rewriterBlock = op->getBlock();
898
899 // Try to handle resolution for each of the result types individually. This is
900 // preferred over type inferrence because it will allow for us to use existing
901 // types directly, as opposed to trying to rebuild the type list.
902 OperandRange resultTypeValues = op.getTypeValues();
903 auto tryResolveResultTypes = [&] {
904 types.reserve(resultTypeValues.size());
905 for (const auto &it : llvm::enumerate(resultTypeValues)) {
906 Value resultType = it.value();
907
908 // Check for an already translated value.
909 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
910 types.push_back(existingRewriteValue);
911 continue;
912 }
913
914 // Check for an input from the matcher.
915 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
916 types.push_back(mapRewriteValue(resultType));
917 continue;
918 }
919
920 // Otherwise, we couldn't infer the result types. Bail out here to see if
921 // we can infer the types for this operation from another way.
922 types.clear();
923 return failure();
924 }
925 return success();
926 };
927 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
928 return;
929
930 // Otherwise, check if the operation has type inference support itself.
931 if (op.hasTypeInference()) {
932 hasInferredResultTypes = true;
933 return;
934 }
935
936 // Look for an operation that was replaced by `op`. The result types will be
937 // inferred from the results that were replaced.
938 for (OpOperand &use : op.getOp().getUses()) {
939 // Check that the use corresponds to a ReplaceOp and that it is the
940 // replacement value, not the operation being replaced.
941 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
942 if (!replOpUser || use.getOperandNumber() == 0)
943 continue;
944 // Make sure the replaced operation was defined before this one. PDL
945 // rewrites only have single block regions, so if the op isn't in the
946 // rewriter block (i.e. the current block of the operation) we already know
947 // it dominates (i.e. it's in the matcher).
948 Value replOpVal = replOpUser.getOpValue();
949 Operation *replacedOp = replOpVal.getDefiningOp();
950 if (replacedOp->getBlock() == rewriterBlock &&
951 !replacedOp->isBeforeInBlock(op))
952 continue;
953
954 Value replacedOpResults = pdl_interp::GetResultsOp::create(
955 builder, replacedOp->getLoc(), mapRewriteValue(replOpVal));
956 types.push_back(pdl_interp::GetValueTypeOp::create(
957 builder, replacedOp->getLoc(), replacedOpResults));
958 return;
959 }
960
961 // If the types could not be inferred from any context and there weren't any
962 // explicit result types, assume the user actually meant for the operation to
963 // have no results.
964 if (resultTypeValues.empty())
965 return;
966
967 // The verifier asserts that the result types of each pdl.getOperation can be
968 // inferred. If we reach here, there is a bug either in the logic above or
969 // in the verifier for pdl.getOperation.
970 op->emitOpError() << "unable to infer result type for operation";
971 llvm_unreachable("unable to infer result type for operation");
972}
973
974//===----------------------------------------------------------------------===//
975// Conversion Pass
976//===----------------------------------------------------------------------===//
977
978namespace {
979struct PDLToPDLInterpPass
980 : public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
981 PDLToPDLInterpPass() = default;
982 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
983 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
984 : configMap(&configMap) {}
985 void runOnOperation() final;
986
987 /// A map containing the configuration for each pattern.
988 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
989};
990} // namespace
991
992/// Convert the given module containing PDL pattern operations into a PDL
993/// Interpreter operations.
994void PDLToPDLInterpPass::runOnOperation() {
995 ModuleOp module = getOperation();
996
997 // Create the main matcher function This function contains all of the match
998 // related functionality from patterns in the module.
999 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
1000 auto matcherFunc = pdl_interp::FuncOp::create(
1001 builder, module.getLoc(),
1002 pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
1003 builder.getFunctionType(builder.getType<pdl::OperationType>(),
1004 /*results=*/{}),
1005 /*attrs=*/ArrayRef<NamedAttribute>());
1006
1007 // Create a nested module to hold the functions invoked for rewriting the IR
1008 // after a successful match.
1009 ModuleOp rewriterModule =
1010 ModuleOp::create(builder, module.getLoc(),
1011 pdl_interp::PDLInterpDialect::getRewriterModuleName());
1012
1013 // Generate the code for the patterns within the module.
1014 PatternLowering generator(matcherFunc, rewriterModule, configMap);
1015 generator.lower(module);
1016
1017 // After generation, delete all of the pattern operations.
1018 for (pdl::PatternOp pattern :
1019 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1020 // Drop the now dead config mappings.
1021 if (configMap)
1022 configMap->erase(pattern);
1023
1024 pattern.erase();
1025 }
1026}
1027
1028std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertPDLToPDLInterpPass(
1030 return std::make_unique<PDLToPDLInterpPass>(configMap);
1031}
return success()
lhs
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
iterator end()
Definition Block.h:154
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
Location getUnknownLoc()
Definition Builders.cpp:25
This class helps build Operations.
Definition Builders.h:207
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:240
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Block & emplaceBlock()
Definition Region.h:46
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Position * getPosition() const
Returns the position on which the question predicate should be checked.
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition Predicate.h:153
Predicates::Kind getKind() const
Returns the kind of this position.
Definition Predicate.h:156
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition Predicate.h:427
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition Predicate.h:46
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
std::unique_ptr<::mlir::Pass > createConvertPDLToPDLInterpPass()
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.