MLIR 22.0.0git
PDLToPDLInterp.cpp
Go to the documentation of this file.
1//===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "PredicateTree.h"
14#include "mlir/Pass/Pass.h"
15#include "llvm/ADT/MapVector.h"
16#include "llvm/ADT/ScopedHashTable.h"
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/TypeSwitch.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::pdl_to_pdl_interp;
28
29//===----------------------------------------------------------------------===//
30// PatternLowering
31//===----------------------------------------------------------------------===//
32
33namespace {
34/// This class generates operations within the PDL Interpreter dialect from a
35/// given module containing PDL pattern operations.
36struct PatternLowering {
37public:
38 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
40
41 /// Generate code for matching and rewriting based on the pattern operations
42 /// within the module.
43 void lower(ModuleOp module);
44
45private:
46 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
47 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
48
49 /// Generate interpreter operations for the tree rooted at the given matcher
50 /// node, in the specified region.
51 Block *generateMatcher(MatcherNode &node, Region &region,
52 Block *block = nullptr);
53
54 /// Get or create an access to the provided positional value in the current
55 /// block. This operation may mutate the provided block pointer if nested
56 /// regions (i.e., pdl_interp.foreach) are required.
57 Value getValueAt(Block *&currentBlock, Position *pos);
58
59 /// Create the interpreter predicate operations. This operation may mutate the
60 /// provided current block pointer if nested regions (pdl_interp.foreach) are required.
61 void generate(BoolNode *boolNode, Block *&currentBlock, Value val);
62
63 /// Create the interpreter switch / predicate operations, with several case
64 /// destinations. This operation never mutates the provided current block
65 /// pointer, because the switch operation does not need Values beyond `val`.
66 void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
67
68 /// Create the interpreter operations to record a successful pattern match
69 /// using the contained root operation. This operation may mutate the current
70 /// block pointer if nested regions (i.e., pdl_interp.foreach) are required.
71 void generate(SuccessNode *successNode, Block *&currentBlock);
72
73 /// Generate a rewriter function for the given pattern operation, and return
74 /// a reference to that function.
75 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76 SmallVectorImpl<Position *> &usedMatchValues);
77
78 /// Generate the rewriter code for the given operation.
79 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80 DenseMap<Value, Value> &rewriteValues,
81 function_ref<Value(Value)> mapRewriteValue);
82 void generateRewriter(pdl::AttributeOp attrOp,
83 DenseMap<Value, Value> &rewriteValues,
84 function_ref<Value(Value)> mapRewriteValue);
85 void generateRewriter(pdl::EraseOp eraseOp,
86 DenseMap<Value, Value> &rewriteValues,
87 function_ref<Value(Value)> mapRewriteValue);
88 void generateRewriter(pdl::OperationOp operationOp,
89 DenseMap<Value, Value> &rewriteValues,
90 function_ref<Value(Value)> mapRewriteValue);
91 void generateRewriter(pdl::RangeOp rangeOp,
92 DenseMap<Value, Value> &rewriteValues,
93 function_ref<Value(Value)> mapRewriteValue);
94 void generateRewriter(pdl::ReplaceOp replaceOp,
95 DenseMap<Value, Value> &rewriteValues,
96 function_ref<Value(Value)> mapRewriteValue);
97 void generateRewriter(pdl::ResultOp resultOp,
98 DenseMap<Value, Value> &rewriteValues,
99 function_ref<Value(Value)> mapRewriteValue);
100 void generateRewriter(pdl::ResultsOp resultOp,
101 DenseMap<Value, Value> &rewriteValues,
102 function_ref<Value(Value)> mapRewriteValue);
103 void generateRewriter(pdl::TypeOp typeOp,
104 DenseMap<Value, Value> &rewriteValues,
105 function_ref<Value(Value)> mapRewriteValue);
106 void generateRewriter(pdl::TypesOp typeOp,
107 DenseMap<Value, Value> &rewriteValues,
108 function_ref<Value(Value)> mapRewriteValue);
109
110 /// Generate the values used for resolving the result types of an operation
111 /// created within a dag rewriter region. If the result types of the operation
112 /// should be inferred, `hasInferredResultTypes` is set to true.
113 void generateOperationResultTypeRewriter(
114 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
115 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
116 bool &hasInferredResultTypes);
117
118 /// A builder to use when generating interpreter operations.
119 OpBuilder builder;
120
121 /// The matcher function used for all match related logic within PDL patterns.
122 pdl_interp::FuncOp matcherFunc;
123
124 /// The rewriter module containing all rewrite-related logic within PDL
125 /// patterns.
126 ModuleOp rewriterModule;
127
128 /// The symbol table of the rewriter module used for insertion.
129 SymbolTable rewriterSymbolTable;
130
131 /// A scoped map connecting a position with the corresponding interpreter
132 /// value.
133 ValueMap values;
134
135 /// A stack of blocks used as the failure destination for matcher nodes that
136 /// don't have an explicit failure path. The top (back) of the stack is the
137 /// destination currently in effect.
137 SmallVector<Block *, 8> failureBlockStack;
138
139 /// A mapping between values defined in a pattern match, and the corresponding
140 /// positional value.
141 DenseMap<Value, Position *> valueToPosition;
142
143 /// The set of operation values whose location will be used for newly
144 /// generated operations.
145 SetVector<Value> locOps;
146
147 /// A mapping between pattern operations and the corresponding configuration
148 /// set.
150
151 /// A mapping from a constraint question to the ApplyConstraintOp
152 /// that implements it.
154};
155} // namespace
156
// Initializes the lowering state: an OpBuilder bound to the matcher function's
// MLIRContext, the target matcher function and rewriter module, a symbol table
// over the rewriter module (used to insert generated rewriter functions), and
// the pattern-configuration map, which may be null (generate(SuccessNode*)
// checks it before use).
157PatternLowering::PatternLowering(
158 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
160 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
161 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162 configMap(configMap) {}
163
164void PatternLowering::lower(ModuleOp module) {
165 PredicateUniquer predicateUniquer;
166 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
167
168 // Define top-level scope for the arguments to the matcher function.
169 ValueMapScope topLevelValueScope(values);
170
171 // Insert the root operation, i.e. argument to the matcher, at the root
172 // position.
173 Block *matcherEntryBlock = &matcherFunc.front();
174 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
175
176 // Generate a root matcher node from the provided PDL module.
177 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
178 module, predicateBuilder, valueToPosition);
179 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
180 assert(failureBlockStack.empty() && "failed to empty the stack");
181
182 // After generation, merged the first matched block into the entry.
183 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
184 firstMatcherBlock->getOperations());
185 firstMatcherBlock->erase();
186}
187
// Emits interpreter code for the matcher subtree rooted at `node` into `block`
// and returns that block. When `block` is null, a new block is appended to
// `region`. If the node has a failure subtree, it is generated first and
// pushed as the failure destination while the node's own code is emitted.
// Otherwise the current top of failureBlockStack is used.
184Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
189 Block *block) {
190 // Push a new scope for the values used by this matcher.
191 if (!block)
192 block = &region.emplaceBlock();
193 ValueMapScope scope(values);
194
195 // If this is the return node, simply insert the corresponding interpreter
196 // finalize.
197 if (isa<ExitNode>(node)) {
198 builder.setInsertionPointToEnd(block);
199 pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
200 return block;
201 }
202
203 // Get the next block in the match sequence.
204 // This is intentionally executed first, before we get the value for the
205 // position associated with the node, so that we preserve a "there exists"
206 // semantics: if getting a value requires an upward traversal (going from a
207 // value to its consumers), we want to perform the check on all the consumers
208 // before we pass control to the failure node.
209 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
210 Block *failureBlock;
211 if (failureNode) {
212 failureBlock = generateMatcher(*failureNode, region);
213 failureBlockStack.push_back(failureBlock);
214 } else {
215 assert(!failureBlockStack.empty() && "expected valid failure block");
216 failureBlock = failureBlockStack.back();
217 }
218
219 // If this node contains a position, get the corresponding value for this
220 // block.
221 Block *currentBlock = block;
222 Position *position = node.getPosition();
223 Value val = position ? getValueAt(currentBlock, position) : Value();
224
225 // If this value corresponds to an operation, record that we are going to use
226 // its location as part of a fused location.
227 bool isOperationValue = val && isa<pdl::OperationType>(val.getType());
228 if (isOperationValue)
229 locOps.insert(val);
230
231 // Dispatch to the correct method based on derived node type.
233 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
234 this->generate(derivedNode, currentBlock, val);
235 })
236 .Case([&](SuccessNode *successNode) {
237 generate(successNode, currentBlock);
238 });
239
240 // Pop all the failure blocks that were inserted due to nesting of
241 // pdl_interp.foreach (pushed by getValueAt).
242 while (failureBlockStack.back() != failureBlock) {
243 failureBlockStack.pop_back();
244 assert(!failureBlockStack.empty() && "unable to locate failure block");
245 }
246
247 // Pop the new failure block.
248 if (failureNode)
249 failureBlockStack.pop_back();
250
251 if (isOperationValue)
252 locOps.remove(val);
253
254 return block;
255}
256
// Returns the interpreter value for `pos`.
//
// A value already recorded in `values` (visible in the current scope) is
// reused. Otherwise the parent position's value is obtained first (recursively),
// the access operation is emitted at the end of `currentBlock`, and the new
// value is recorded for `pos`.
//
// When a pdl_interp.foreach is needed, its continuation block is pushed onto
// failureBlockStack and `currentBlock` is redirected into the loop body, so
// callers see the updated block.
257Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
258 if (Value val = values.lookup(pos))
259 return val;
260
261 // Get the value for the parent position.
262 Value parentVal;
263 if (Position *parent = pos->getParent())
264 parentVal = getValueAt(currentBlock, parent);
265
266 // TODO: Use a location from the position.
267 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
268 builder.setInsertionPointToEnd(currentBlock);
269 Value value;
270 switch (pos->getKind()) {
272 auto *operationPos = cast<OperationPosition>(pos);
273 if (operationPos->isOperandDefiningOp())
274 // Standard (downward) traversal which directly follows the defining op.
275 value = pdl_interp::GetDefiningOpOp::create(
276 builder, loc, builder.getType<pdl::OperationType>(), parentVal);
277 else
278 // A passthrough operation position.
279 value = parentVal;
280 break;
281 }
283 auto *usersPos = cast<UsersPosition>(pos);
284
285 // The first operation retrieves the representative value of a range.
286 // This applies only when the parent is a range of values and we were
287 // requested to use a representative value (e.g., upward traversal).
288 if (isa<pdl::RangeType>(parentVal.getType()) &&
289 usersPos->useRepresentative())
290 value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
291 else
292 value = parentVal;
293
294 // The second operation retrieves the users.
295 value = pdl_interp::GetUsersOp::create(builder, loc, value);
296 break;
297 }
299 assert(!failureBlockStack.empty() && "expected valid failure block");
300 auto foreach = pdl_interp::ForEachOp::create(
301 builder, loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
302 value = foreach.getLoopVariable();
303
304 // Create the continuation block.
305 Block *continueBlock = builder.createBlock(&foreach.getRegion());
306 pdl_interp::ContinueOp::create(builder, loc);
307 failureBlockStack.push_back(continueBlock);
308
// Code for the rest of the match is emitted inside the loop body from here on.
309 currentBlock = &foreach.getRegion().front();
310 break;
311 }
313 auto *operandPos = cast<OperandPosition>(pos);
314 value = pdl_interp::GetOperandOp::create(
315 builder, loc, builder.getType<pdl::ValueType>(), parentVal,
316 operandPos->getOperandNumber());
317 break;
318 }
320 auto *operandPos = cast<OperandGroupPosition>(pos);
321 Type valueTy = builder.getType<pdl::ValueType>();
322 value = pdl_interp::GetOperandsOp::create(
323 builder, loc,
324 operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
325 parentVal, operandPos->getOperandGroupNumber());
326 break;
327 }
329 auto *attrPos = cast<AttributePosition>(pos);
330 value = pdl_interp::GetAttributeOp::create(
331 builder, loc, builder.getType<pdl::AttributeType>(), parentVal,
332 attrPos->getName().strref());
333 break;
334 }
335 case Predicates::TypePos: {
336 if (isa<pdl::AttributeType>(parentVal.getType()))
337 value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
338 else
339 value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
340 break;
341 }
343 auto *resPos = cast<ResultPosition>(pos);
344 value = pdl_interp::GetResultOp::create(
345 builder, loc, builder.getType<pdl::ValueType>(), parentVal,
346 resPos->getResultNumber());
347 break;
348 }
350 auto *resPos = cast<ResultGroupPosition>(pos);
351 Type valueTy = builder.getType<pdl::ValueType>();
352 value = pdl_interp::GetResultsOp::create(
353 builder, loc,
354 resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
356 break;
357 }
359 auto *attrPos = cast<AttributeLiteralPosition>(pos);
360 value = pdl_interp::CreateAttributeOp::create(builder, loc,
361 attrPos->getValue());
362 break;
363 }
365 auto *typePos = cast<TypeLiteralPosition>(pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
369 else
370 value = pdl_interp::CreateTypesOp::create(builder, loc,
371 cast<ArrayAttr>(rawTypeAttr));
372 break;
373 }
375 // Due to the order of traversal, the ApplyConstraintOp has already been
376 // created and we can find it in constraintOpMap.
377 auto *constrResPos = cast<ConstraintPosition>(pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
381 break;
382 }
383 default:
384 llvm_unreachable("Generating unknown Position getter");
385 break;
386 }
387
388 values.insert(pos, value);
389 return value;
390}
391
// Emits the single-answer predicate for `boolNode` on `val` at the end of
// `currentBlock`. The predicate branches to a new success block on a match and
// to the current top of failureBlockStack otherwise. The node's success subtree
// is then generated into the success block. For constraint questions, the
// created pdl_interp.apply_constraint is recorded in constraintOpMap so that
// later positions can read its results.
392void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
393 Value val) {
394 Location loc = val.getLoc();
395 Qualifier *question = boolNode->getQuestion();
396 Qualifier *answer = boolNode->getAnswer();
// NOTE(review): `region` is captured before the getValueAt calls below. Those
// calls may redirect `currentBlock` into a pdl_interp.foreach body. If that
// happens, `success` is created in the outer region while the predicate is
// emitted inside the loop. Confirm this cannot occur, or recompute `region`
// after the queries.
397 Region *region = currentBlock->getParent();
398
399 // Execute the getValue queries first, so that we create success
400 // matcher in the correct (possibly nested) region.
401 SmallVector<Value> args;
402 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(getValueAt(currentBlock, position));
407 }
408
409 // Generate a new block as success successor and get the failure successor.
410 Block *success = &region->emplaceBlock();
411 Block *failure = failureBlockStack.back();
412
413 // Create the predicate.
414 builder.setInsertionPointToEnd(currentBlock);
415 Predicates::Kind kind = question->getKind();
416 switch (kind) {
418 pdl_interp::IsNotNullOp::create(builder, loc, val, success, failure);
419 break;
421 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422 pdl_interp::CheckOperationNameOp::create(
423 builder, loc, val, opNameAnswer->getValue().getStringRef(), success,
424 failure);
425 break;
426 }
428 auto *ans = cast<TypeAnswer>(answer);
429 if (isa<pdl::RangeType>(val.getType()))
430 pdl_interp::CheckTypesOp::create(builder, loc, val,
431 llvm::cast<ArrayAttr>(ans->getValue()),
432 success, failure);
433 else
434 pdl_interp::CheckTypeOp::create(builder, loc, val,
435 llvm::cast<TypeAttr>(ans->getValue()),
436 success, failure);
437 break;
438 }
440 auto *ans = cast<AttributeAnswer>(answer);
441 pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
442 success, failure);
443 break;
444 }
447 pdl_interp::CheckOperandCountOp::create(
448 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
449 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
450 success, failure);
451 break;
454 pdl_interp::CheckResultCountOp::create(
455 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
456 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
457 success, failure);
458 break;
// Equality: swap the successors when the expected answer is "false".
460 bool trueAnswer = isa<TrueAnswer>(answer);
461 pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
462 trueAnswer ? success : failure,
463 trueAnswer ? failure : success);
464 break;
465 }
467 auto *cstQuestion = cast<ConstraintQuestion>(question);
468 auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
469 builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
470 args, cstQuestion->getIsNegated(), success, failure);
471
472 constraintOpMap.insert({cstQuestion, applyConstraintOp});
473 break;
474 }
475 default:
476 llvm_unreachable("Generating unknown Predicate operation");
477 }
478
479 // Generate the matcher in the current (potentially nested) region.
480 // This might use the results of the current predicate.
481 generateMatcher(*boolNode->getSuccessNode(), *region, success);
482}
483
484template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
485static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
486 llvm::MapVector<Qualifier *, Block *> &dests) {
487 std::vector<ValT> values;
488 std::vector<Block *> blocks;
489 values.reserve(dests.size());
490 blocks.reserve(dests.size());
491 for (const auto &it : dests) {
492 blocks.push_back(it.second);
493 values.push_back(cast<PredT>(it.first)->getValue());
494 }
495 OpT::create(builder, val.getLoc(), val, values, defaultDest, blocks);
496}
497
// Emits the multi-way test for `switchNode` on `val` into `currentBlock`.
// Operand/result-count "at least" questions become a chain of count checks
// (described below). All other questions become a single pdl_interp switch
// operation whose default destination is the current top of failureBlockStack.
// Each child subtree is generated into its own block of the current region.
498void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
499 Value val) {
500 Qualifier *question = switchNode->getQuestion();
501 Region *region = currentBlock->getParent();
502 Block *defaultDest = failureBlockStack.back();
503
504 // If the switch question is not an exact answer, i.e. for the `at_least`
505 // cases, we generate a special block sequence.
506 Predicates::Kind kind = question->getKind();
509 // Order the children such that the cases are in reverse numerical order.
510 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
511 llvm::seq<unsigned>(0, switchNode->getChildren().size()));
512 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
513 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
514 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
515 });
516
517 // Build the destination for each child using the next highest child as a
518 // failure destination. This essentially creates the following control
519 // flow:
520 //
521 // if (operand_count < 1)
522 // goto failure
523 // if (child1.match())
524 // ...
525 //
526 // if (operand_count < 2)
527 // goto failure
528 // if (child2.match())
529 // ...
530 //
531 // failure:
532 // ...
533 //
534 failureBlockStack.push_back(defaultDest);
535 Location loc = val.getLoc();
536 for (unsigned idx : sortedChildren) {
537 auto &child = switchNode->getChild(idx);
538 Block *childBlock = generateMatcher(*child.second, *region);
539 Block *predicateBlock = builder.createBlock(childBlock);
540 builder.setInsertionPointToEnd(predicateBlock);
541 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
542 switch (kind) {
544 pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
545 /*compareAtLeast=*/true,
546 childBlock, defaultDest);
547 break;
549 pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
550 /*compareAtLeast=*/true,
551 childBlock, defaultDest);
552 break;
553 default:
554 llvm_unreachable("Generating invalid AtLeast operation");
555 }
// The next (lower-count) child falls back to this predicate on failure.
556 failureBlockStack.back() = predicateBlock;
557 }
// The lowest-count predicate block is the entry of the chain; inline it.
558 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
559 currentBlock->getOperations().splice(currentBlock->end(),
560 firstPredicateBlock->getOperations());
561 firstPredicateBlock->erase();
562 return;
563 }
564
565 // Otherwise, generate each of the children and generate an interpreter
566 // switch.
567 llvm::MapVector<Qualifier *, Block *> children;
568 for (auto &it : switchNode->getChildren())
569 children.insert({it.first, generateMatcher(*it.second, *region)});
570 builder.setInsertionPointToEnd(currentBlock);
571
572 switch (question->getKind()) {
574 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
575 int32_t>(val, defaultDest, builder, children);
577 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
578 int32_t>(val, defaultDest, builder, children);
580 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
581 OperationNameAnswer>(val, defaultDest, builder,
582 children);
584 if (isa<pdl::RangeType>(val.getType())) {
586 val, defaultDest, builder, children);
587 }
589 val, defaultDest, builder, children);
592 val, defaultDest, builder, children);
593 default:
594 llvm_unreachable("Generating unknown switch predicate.");
595 }
596}
597
598void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
599 pdl::PatternOp pattern = successNode->getPattern();
600 Value root = successNode->getRoot();
601
602 // Generate a rewriter for the pattern this success node represents, and track
603 // any values used from the match region.
604 SmallVector<Position *, 8> usedMatchValues;
605 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
606
607 // Process any values used in the rewrite that are defined in the match.
608 std::vector<Value> mappedMatchValues;
609 mappedMatchValues.reserve(usedMatchValues.size());
610 for (Position *position : usedMatchValues)
611 mappedMatchValues.push_back(getValueAt(currentBlock, position));
612
613 // Collect the set of operations generated by the rewriter.
614 SmallVector<StringRef, 4> generatedOps;
615 for (auto op :
616 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
617 generatedOps.push_back(*op.getOpName());
618 ArrayAttr generatedOpsAttr;
619 if (!generatedOps.empty())
620 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
621
622 // Grab the root kind if present.
623 StringAttr rootKindAttr;
624 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
625 if (std::optional<StringRef> rootKind = rootOp.getOpName())
626 rootKindAttr = builder.getStringAttr(*rootKind);
627
628 builder.setInsertionPointToEnd(currentBlock);
629 auto matchOp = pdl_interp::RecordMatchOp::create(
630 builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
631 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
632 failureBlockStack.back());
633
634 // Set the config of the lowered match to the parent pattern.
635 if (configMap)
636 configMap->try_emplace(matchOp, configMap->lookup(pattern));
637}
638
639SymbolRefAttr PatternLowering::generateRewriter(
640 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
641 builder.setInsertionPointToEnd(rewriterModule.getBody());
642 auto rewriterFunc = pdl_interp::FuncOp::create(
643 builder, pattern.getLoc(), "pdl_generated_rewriter",
644 builder.getFunctionType({}, {}));
645 rewriterSymbolTable.insert(rewriterFunc);
646
647 // Generate the rewriter function body.
648 builder.setInsertionPointToEnd(&rewriterFunc.front());
649
650 // Map an input operand of the pattern to a generated interpreter value.
651 DenseMap<Value, Value> rewriteValues;
652 auto mapRewriteValue = [&](Value oldValue) {
653 Value &newValue = rewriteValues[oldValue];
654 if (newValue)
655 return newValue;
656
657 // Prefer materializing constants directly when possible.
658 Operation *oldOp = oldValue.getDefiningOp();
659 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
660 if (Attribute value = attrOp.getValueAttr()) {
661 return newValue = pdl_interp::CreateAttributeOp::create(
662 builder, attrOp.getLoc(), value);
663 }
664 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
665 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
666 return newValue = pdl_interp::CreateTypeOp::create(
667 builder, typeOp.getLoc(), type);
668 }
669 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
670 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
671 return newValue = pdl_interp::CreateTypesOp::create(
672 builder, typeOp.getLoc(), typeOp.getType(), type);
673 }
674 }
675
676 // Otherwise, add this as an input to the rewriter.
677 Position *inputPos = valueToPosition.lookup(oldValue);
678 assert(inputPos && "expected value to be a pattern input");
679 usedMatchValues.push_back(inputPos);
680 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
681 oldValue.getLoc());
682 };
683
684 // If this is a custom rewriter, simply dispatch to the registered rewrite
685 // method.
686 pdl::RewriteOp rewriter = pattern.getRewriter();
687 if (StringAttr rewriteName = rewriter.getNameAttr()) {
688 SmallVector<Value> args;
689 if (rewriter.getRoot())
690 args.push_back(mapRewriteValue(rewriter.getRoot()));
691 auto mappedArgs =
692 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
693 args.append(mappedArgs.begin(), mappedArgs.end());
694 pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
695 /*results=*/TypeRange(), rewriteName,
696 args);
697 } else {
698 // Otherwise this is a dag rewriter defined using PDL operations.
699 for (Operation &rewriteOp : *rewriter.getBody()) {
700 llvm::TypeSwitch<Operation *>(&rewriteOp)
701 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
702 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
703 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) {
704 this->generateRewriter(op, rewriteValues, mapRewriteValue);
705 });
706 }
707 }
708
709 // Update the signature of the rewrite function.
710 rewriterFunc.setType(builder.getFunctionType(
711 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
712 /*results=*/{}));
713
714 pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
715 return SymbolRefAttr::get(
716 builder.getContext(),
717 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
718 SymbolRefAttr::get(rewriterFunc));
719}
720
721void PatternLowering::generateRewriter(
722 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
723 function_ref<Value(Value)> mapRewriteValue) {
724 SmallVector<Value, 2> arguments;
725 for (Value argument : rewriteOp.getArgs())
726 arguments.push_back(mapRewriteValue(argument));
727 auto interpOp = pdl_interp::ApplyRewriteOp::create(
728 builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
729 rewriteOp.getNameAttr(), arguments);
730 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
731 rewriteValues[std::get<0>(it)] = std::get<1>(it);
732}
733
734void PatternLowering::generateRewriter(
735 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
736 function_ref<Value(Value)> mapRewriteValue) {
737 Value newAttr = pdl_interp::CreateAttributeOp::create(
738 builder, attrOp.getLoc(), attrOp.getValueAttr());
739 rewriteValues[attrOp] = newAttr;
740}
741
742void PatternLowering::generateRewriter(
743 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
744 function_ref<Value(Value)> mapRewriteValue) {
745 pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
746 mapRewriteValue(eraseOp.getOpValue()));
747}
748
749void PatternLowering::generateRewriter(
750 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
751 function_ref<Value(Value)> mapRewriteValue) {
752 SmallVector<Value, 4> operands;
753 for (Value operand : operationOp.getOperandValues())
754 operands.push_back(mapRewriteValue(operand));
755
756 SmallVector<Value, 4> attributes;
757 for (Value attr : operationOp.getAttributeValues())
758 attributes.push_back(mapRewriteValue(attr));
759
760 bool hasInferredResultTypes = false;
761 SmallVector<Value, 2> types;
762 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
763 rewriteValues, hasInferredResultTypes);
764
765 // Create the new operation.
766 Location loc = operationOp.getLoc();
767 Value createdOp = pdl_interp::CreateOperationOp::create(
768 builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
769 operands, attributes, operationOp.getAttributeValueNames());
770 rewriteValues[operationOp.getOp()] = createdOp;
771
772 // Generate accesses for any results that have their types constrained.
773 // Handle the case where there is a single range representing all of the
774 // result types.
775 OperandRange resultTys = operationOp.getTypeValues();
776 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
777 Value &type = rewriteValues[resultTys[0]];
778 if (!type) {
779 auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
780 type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
781 }
782 return;
783 }
784
785 // Otherwise, populate the individual results.
786 bool seenVariableLength = false;
787 Type valueTy = builder.getType<pdl::ValueType>();
788 Type valueRangeTy = pdl::RangeType::get(valueTy);
789 for (const auto &it : llvm::enumerate(resultTys)) {
790 Value &type = rewriteValues[it.value()];
791 if (type)
792 continue;
793 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
794 seenVariableLength |= isVariadic;
795
796 // After a variable length result has been seen, we need to use result
797 // groups because the exact index of the result is not statically known.
798 Value resultVal;
799 if (seenVariableLength)
800 resultVal = pdl_interp::GetResultsOp::create(
801 builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
802 it.index());
803 else
804 resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
805 createdOp, it.index());
806 type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
807 }
808}
809
810void PatternLowering::generateRewriter(
811 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
812 function_ref<Value(Value)> mapRewriteValue) {
813 SmallVector<Value, 4> replOperands;
814 for (Value operand : rangeOp.getArguments())
815 replOperands.push_back(mapRewriteValue(operand));
816 rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
817 builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
818}
819
820void PatternLowering::generateRewriter(
821 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
822 function_ref<Value(Value)> mapRewriteValue) {
823 SmallVector<Value, 4> replOperands;
824
825 // If the replacement was another operation, get its results. `pdl` allows
826 // for using an operation for simplicitly, but the interpreter isn't as
827 // user facing.
828 if (Value replOp = replaceOp.getReplOperation()) {
829 // Don't use replace if we know the replaced operation has no results.
830 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
831 if (!opOp || !opOp.getTypeValues().empty()) {
832 replOperands.push_back(pdl_interp::GetResultsOp::create(
833 builder, replOp.getLoc(), mapRewriteValue(replOp)));
834 }
835 } else {
836 for (Value operand : replaceOp.getReplValues())
837 replOperands.push_back(mapRewriteValue(operand));
838 }
839
840 // If there are no replacement values, just create an erase instead.
841 if (replOperands.empty()) {
842 pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
843 mapRewriteValue(replaceOp.getOpValue()));
844 return;
845 }
846
847 pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
848 mapRewriteValue(replaceOp.getOpValue()),
849 replOperands);
850}
851
852void PatternLowering::generateRewriter(
853 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
854 function_ref<Value(Value)> mapRewriteValue) {
855 rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
856 builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
857 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
858}
859
860void PatternLowering::generateRewriter(
861 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
862 function_ref<Value(Value)> mapRewriteValue) {
863 rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
864 builder, resultOp.getLoc(), resultOp.getType(),
865 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
866}
867
868void PatternLowering::generateRewriter(
869 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
870 function_ref<Value(Value)> mapRewriteValue) {
871 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
872 // type.
873 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
874 rewriteValues[typeOp] =
875 pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
876 }
877}
878
879void PatternLowering::generateRewriter(
880 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
881 function_ref<Value(Value)> mapRewriteValue) {
882 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
883 // type.
884 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
885 rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
886 builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
887 }
888}
889
890void PatternLowering::generateOperationResultTypeRewriter(
891 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
892 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
893 bool &hasInferredResultTypes) {
894 Block *rewriterBlock = op->getBlock();
895
896 // Try to handle resolution for each of the result types individually. This is
897 // preferred over type inferrence because it will allow for us to use existing
898 // types directly, as opposed to trying to rebuild the type list.
899 OperandRange resultTypeValues = op.getTypeValues();
900 auto tryResolveResultTypes = [&] {
901 types.reserve(resultTypeValues.size());
902 for (const auto &it : llvm::enumerate(resultTypeValues)) {
903 Value resultType = it.value();
904
905 // Check for an already translated value.
906 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
907 types.push_back(existingRewriteValue);
908 continue;
909 }
910
911 // Check for an input from the matcher.
912 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
913 types.push_back(mapRewriteValue(resultType));
914 continue;
915 }
916
917 // Otherwise, we couldn't infer the result types. Bail out here to see if
918 // we can infer the types for this operation from another way.
919 types.clear();
920 return failure();
921 }
922 return success();
923 };
924 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
925 return;
926
927 // Otherwise, check if the operation has type inference support itself.
928 if (op.hasTypeInference()) {
929 hasInferredResultTypes = true;
930 return;
931 }
932
933 // Look for an operation that was replaced by `op`. The result types will be
934 // inferred from the results that were replaced.
935 for (OpOperand &use : op.getOp().getUses()) {
936 // Check that the use corresponds to a ReplaceOp and that it is the
937 // replacement value, not the operation being replaced.
938 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
939 if (!replOpUser || use.getOperandNumber() == 0)
940 continue;
941 // Make sure the replaced operation was defined before this one. PDL
942 // rewrites only have single block regions, so if the op isn't in the
943 // rewriter block (i.e. the current block of the operation) we already know
944 // it dominates (i.e. it's in the matcher).
945 Value replOpVal = replOpUser.getOpValue();
946 Operation *replacedOp = replOpVal.getDefiningOp();
947 if (replacedOp->getBlock() == rewriterBlock &&
948 !replacedOp->isBeforeInBlock(op))
949 continue;
950
951 Value replacedOpResults = pdl_interp::GetResultsOp::create(
952 builder, replacedOp->getLoc(), mapRewriteValue(replOpVal));
953 types.push_back(pdl_interp::GetValueTypeOp::create(
954 builder, replacedOp->getLoc(), replacedOpResults));
955 return;
956 }
957
958 // If the types could not be inferred from any context and there weren't any
959 // explicit result types, assume the user actually meant for the operation to
960 // have no results.
961 if (resultTypeValues.empty())
962 return;
963
964 // The verifier asserts that the result types of each pdl.getOperation can be
965 // inferred. If we reach here, there is a bug either in the logic above or
966 // in the verifier for pdl.getOperation.
967 op->emitOpError() << "unable to infer result type for operation";
968 llvm_unreachable("unable to infer result type for operation");
969}
970
971//===----------------------------------------------------------------------===//
972// Conversion Pass
973//===----------------------------------------------------------------------===//
974
975namespace {
976struct PDLToPDLInterpPass
977 : public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
978 PDLToPDLInterpPass() = default;
979 PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
980 PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
981 : configMap(&configMap) {}
982 void runOnOperation() final;
983
984 /// A map containing the configuration for each pattern.
985 DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
986};
987} // namespace
988
989/// Convert the given module containing PDL pattern operations into a PDL
990/// Interpreter operations.
991void PDLToPDLInterpPass::runOnOperation() {
992 ModuleOp module = getOperation();
993
994 // Create the main matcher function This function contains all of the match
995 // related functionality from patterns in the module.
996 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
997 auto matcherFunc = pdl_interp::FuncOp::create(
998 builder, module.getLoc(),
999 pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
1000 builder.getFunctionType(builder.getType<pdl::OperationType>(),
1001 /*results=*/{}),
1002 /*attrs=*/ArrayRef<NamedAttribute>());
1003
1004 // Create a nested module to hold the functions invoked for rewriting the IR
1005 // after a successful match.
1006 ModuleOp rewriterModule =
1007 ModuleOp::create(builder, module.getLoc(),
1008 pdl_interp::PDLInterpDialect::getRewriterModuleName());
1009
1010 // Generate the code for the patterns within the module.
1011 PatternLowering generator(matcherFunc, rewriterModule, configMap);
1012 generator.lower(module);
1013
1014 // After generation, delete all of the pattern operations.
1015 for (pdl::PatternOp pattern :
1016 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1017 // Drop the now dead config mappings.
1018 if (configMap)
1019 configMap->erase(pattern);
1020
1021 pattern.erase();
1022 }
1023}
1024
1025std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertPDLToPDLInterpPass(
1027 return std::make_unique<PDLToPDLInterpPass>(configMap);
1028}
return success()
lhs
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition Block.cpp:27
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
iterator end()
Definition Block.h:144
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
Location getUnknownLoc()
Definition Builders.cpp:25
This class helps build Operations.
Definition Builders.h:207
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:240
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Block & emplaceBlock()
Definition Region.h:46
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Position * getPosition() const
Returns the position on which the question predicate should be checked.
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Definition Predicate.h:153
Predicates::Kind getKind() const
Returns the kind of this position.
Definition Predicate.h:156
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Definition Predicate.h:427
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
@ OperationPos
Positions, ordered by decreasing priority.
Definition Predicate.h:46
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
std::unique_ptr<::mlir::Pass > createConvertPDLToPDLInterpPass()
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.