15#include "llvm/ADT/MapVector.h"
16#include "llvm/ADT/ScopedHashTable.h"
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/TypeSwitch.h"
22#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23#include "mlir/Conversion/Passes.h.inc"
/// Lowers the PDL patterns of a module into pdl_interp form. Matching logic is
/// emitted into `matcherFunc`, and one rewriter function per pattern is emitted
/// into `rewriterModule`.
/// NOTE(review): this view of the struct is an extraction with dropped lines
/// (remaining constructor parameters, other members, closing brace); only the
/// visible members are documented here.
36struct PatternLowering {
38 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
/// Build the matcher tree for the PDL patterns in `module` and lower it.
43 void lower(ModuleOp module);
/// Scoped map from a matcher Position to the Value already materialized for
/// it. `generateMatcher` opens one scope per node, so entries inserted while
/// lowering a node are scoped to that node.
46 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
47 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
/// Generate matcher code for `node` inside `region`; `block` defaults to
/// nullptr.
51 Block *generateMatcher(MatcherNode &node, Region ®ion,
52 Block *block =
nullptr);
/// Return (materializing on first use) the value at position `pos`.
/// `currentBlock` may be updated, e.g. when a foreach loop is introduced.
57 Value getValueAt(
Block *¤tBlock, Position *pos);
/// Emit the check for a boolean predicate node.
61 void generate(BoolNode *boolNode,
Block *¤tBlock, Value val);
/// Emit a switch for a switch predicate node.
66 void generate(SwitchNode *switchNode,
Block *currentBlock, Value val);
/// Emit the match recording for a successful pattern match.
71 void generate(SuccessNode *successNode,
Block *¤tBlock);
/// Create the rewriter function for `pattern`. The positions of matcher values
/// it consumes are appended to `usedMatchValues`.
75 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76 SmallVectorImpl<Position *> &usedMatchValues);
/// Per-op lowering of the PDL rewrite body; `rewriteValues` maps PDL values to
/// their rewriter-side counterparts.
79 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80 DenseMap<Value, Value> &rewriteValues,
82 void generateRewriter(pdl::AttributeOp attrOp,
83 DenseMap<Value, Value> &rewriteValues,
85 void generateRewriter(pdl::EraseOp eraseOp,
86 DenseMap<Value, Value> &rewriteValues,
88 void generateRewriter(pdl::OperationOp operationOp,
89 DenseMap<Value, Value> &rewriteValues,
91 void generateRewriter(pdl::RangeOp rangeOp,
92 DenseMap<Value, Value> &rewriteValues,
94 void generateRewriter(pdl::ReplaceOp replaceOp,
95 DenseMap<Value, Value> &rewriteValues,
97 void generateRewriter(pdl::ResultOp resultOp,
98 DenseMap<Value, Value> &rewriteValues,
100 void generateRewriter(pdl::ResultsOp resultOp,
101 DenseMap<Value, Value> &rewriteValues,
103 void generateRewriter(pdl::TypeOp typeOp,
104 DenseMap<Value, Value> &rewriteValues,
106 void generateRewriter(pdl::TypesOp typeOp,
107 DenseMap<Value, Value> &rewriteValues,
/// Compute the result types for the operation created from `op`. Sets
/// `hasInferredResultTypes` when the types are to be inferred instead.
113 void generateOperationResultTypeRewriter(
114 pdl::OperationOp op,
function_ref<Value(Value)> mapRewriteValue,
115 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
116 bool &hasInferredResultTypes);
/// The pdl_interp function that matcher code is emitted into.
122 pdl_interp::FuncOp matcherFunc;
/// Module holding the generated rewriter functions.
126 ModuleOp rewriterModule;
/// Symbol table over `rewriterModule`; insertion renames on collision.
129 SymbolTable rewriterSymbolTable;
/// Stack of failure destinations. The back is the block that the
/// currently-generated checks branch to on failure.
137 SmallVector<Block *, 8> failureBlockStack;
/// Store the matcher function, the rewriter module and the config map. Also
/// builds a SymbolTable over the rewriter module, so that rewriter functions
/// inserted later are renamed if their names collide.
/// NOTE(review): one parameter line of this signature is missing in this
/// extraction (the source of `configMap`).
157PatternLowering::PatternLowering(
158 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
160 : builder(matcherFunc.
getContext()), matcherFunc(matcherFunc),
161 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162 configMap(configMap) {}
/// Build the matcher tree for the PDL patterns in `module` and lower it into
/// the body of `matcherFunc`. Before generation starts, the predicate
/// builder's root position is bound to the first argument of the matcher
/// entry block.
164void PatternLowering::lower(ModuleOp module) {
165 PredicateUniquer predicateUniquer;
166 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
// Top-level scope for position -> value bindings.
169 ValueMapScope topLevelValueScope(values);
173 Block *matcherEntryBlock = &matcherFunc.front();
174 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->
getArgument(0));
178 module, predicateBuilder, valueToPosition);
179 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
// Every failure block pushed during generation must have been popped again.
180 assert(failureBlockStack.empty() &&
"failed to empty the stack");
// NOTE(review): the statements between the assert and this erase are missing
// from this extraction. Presumably the first matcher block's contents are moved
// into the entry block before it is erased -- confirm.
185 firstMatcherBlock->
erase();
/// Generate matcher code for `node` into `region` (optionally into `block`,
/// which defaults to nullptr).
///
/// - A ValueMapScope is opened, so position values materialized for this node
///   are scoped to it.
/// - An ExitNode lowers to `pdl_interp.finalize`.
/// - Otherwise the node's position value is resolved and generation
///   dispatches to the BoolNode / SwitchNode / SuccessNode overloads of
///   `generate`.
///
/// NOTE(review): several lines are missing from this extraction, including the
/// return statement. Callers use the returned block as an entry or failure
/// destination.
188Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion,
193 ValueMapScope scope(values);
// Exit nodes terminate matching.
197 if (isa<ExitNode>(node)) {
199 pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
// If the node has a failure node, generate it first and push its block as the
// failure destination. Otherwise reuse the block on top of failureBlockStack.
209 std::unique_ptr<MatcherNode> &failureNode = node.
getFailureNode();
212 failureBlock = generateMatcher(*failureNode, region);
213 failureBlockStack.push_back(failureBlock);
215 assert(!failureBlockStack.empty() &&
"expected valid failure block");
216 failureBlock = failureBlockStack.back();
221 Block *currentBlock = block;
// Resolve the value at this node's position (if any). This may move
// `currentBlock`, e.g. into a foreach body.
223 Value val = position ? getValueAt(currentBlock, position) : Value();
227 bool isOperationValue = val && isa<pdl::OperationType>(val.
getType());
228 if (isOperationValue)
233 .Case<BoolNode, SwitchNode>([&](
auto *derivedNode) {
234 this->generate(derivedNode, currentBlock, val);
236 .Case([&](SuccessNode *successNode) {
237 generate(successNode, currentBlock);
// Pop the failure blocks pushed while generating this node (for example by
// foreach loops created in getValueAt) until this node's failure block is on
// top. Then pop that block as well; presumably this happens only when it was
// pushed above, as the guarding condition is not visible here.
242 while (failureBlockStack.back() != failureBlock) {
243 failureBlockStack.pop_back();
244 assert(!failureBlockStack.empty() &&
"unable to locate failure block");
249 failureBlockStack.pop_back();
251 if (isOperationValue)
/// Return the Value for position `pos`.
///
/// - If `values` already holds one, it is returned directly.
/// - Otherwise the parent position's value is resolved recursively, a
///   pdl_interp op that extracts this position from it is emitted, and the
///   result is cached in `values`.
///
/// `currentBlock` is passed by reference because a UsersPosition lowers to a
/// `pdl_interp.foreach`. Generation then continues inside the loop body.
257Value PatternLowering::getValueAt(
Block *¤tBlock, Position *pos) {
258 if (Value val = values.lookup(pos))
264 parentVal = getValueAt(currentBlock, parent);
// Operation position: optionally the defining op of the parent operand.
272 auto *operationPos = cast<OperationPosition>(pos);
273 if (operationPos->isOperandDefiningOp())
275 value = pdl_interp::GetDefiningOpOp::create(
276 builder, loc, builder.
getType<pdl::OperationType>(), parentVal);
// Users position: take the users of the parent value (or of the first element
// of a range, when a representative is requested) and iterate over them.
283 auto *usersPos = cast<UsersPosition>(pos);
288 if (isa<pdl::RangeType>(parentVal.
getType()) &&
289 usersPos->useRepresentative())
290 value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
295 value = pdl_interp::GetUsersOp::create(builder, loc, value);
299 assert(!failureBlockStack.empty() &&
"expected valid failure block");
300 auto foreach = pdl_interp::ForEachOp::create(
301 builder, loc, parentVal, failureBlockStack.back(),
true);
302 value =
foreach.getLoopVariable();
// The loop's continue block becomes the failure destination for code generated
// inside the loop body, and generation continues in the loop region.
306 pdl_interp::ContinueOp::create(builder, loc);
307 failureBlockStack.push_back(continueBlock);
309 currentBlock = &
foreach.getRegion().
front();
// Single operand of the parent operation.
313 auto *operandPos = cast<OperandPosition>(pos);
314 value = pdl_interp::GetOperandOp::create(
315 builder, loc, builder.
getType<pdl::ValueType>(), parentVal,
316 operandPos->getOperandNumber());
// Operand group; range-typed when the group is variadic.
320 auto *operandPos = cast<OperandGroupPosition>(pos);
321 Type valueTy = builder.
getType<pdl::ValueType>();
322 value = pdl_interp::GetOperandsOp::create(
324 operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
325 parentVal, operandPos->getOperandGroupNumber());
// Named attribute of the parent operation.
329 auto *attrPos = cast<AttributePosition>(pos);
330 value = pdl_interp::GetAttributeOp::create(
331 builder, loc, builder.
getType<pdl::AttributeType>(), parentVal,
332 attrPos->getName().strref());
// Type of an attribute or of a value, depending on the parent's type.
336 if (isa<pdl::AttributeType>(parentVal.
getType()))
337 value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
339 value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
// Single result of the parent operation.
343 auto *resPos = cast<ResultPosition>(pos);
344 value = pdl_interp::GetResultOp::create(
345 builder, loc, builder.
getType<pdl::ValueType>(), parentVal,
346 resPos->getResultNumber());
// Result group; range-typed when the group is variadic.
350 auto *resPos = cast<ResultGroupPosition>(pos);
351 Type valueTy = builder.
getType<pdl::ValueType>();
352 value = pdl_interp::GetResultsOp::create(
354 resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
// Constant attribute literal.
359 auto *attrPos = cast<AttributeLiteralPosition>(pos);
360 value = pdl_interp::CreateAttributeOp::create(builder, loc,
361 attrPos->getValue());
// Constant type literal: a single TypeAttr, otherwise an ArrayAttr of types.
365 auto *typePos = cast<TypeLiteralPosition>(pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
370 value = pdl_interp::CreateTypesOp::create(builder, loc,
371 cast<ArrayAttr>(rawTypeAttr));
// Constraint result: read from the ApplyConstraintOp that was recorded in
// constraintOpMap when the constraint question was generated.
377 auto *constrResPos = cast<ConstraintPosition>(pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
384 llvm_unreachable(
"Generating unknown Position getter");
// Cache the materialized value for later lookups in this scope.
388 values.insert(pos, value);
/// Emit the check for boolean predicate node `boolNode` on value `val`.
/// Each emitted pdl_interp check op branches to `success` when the node's
/// answer holds, and to the failure block on top of `failureBlockStack`
/// otherwise.
/// NOTE(review): the `val` parameter line and the creation of the `success`
/// block are missing from this extraction.
392void PatternLowering::generate(BoolNode *boolNode,
Block *¤tBlock,
394 Location loc = val.
getLoc();
396 Qualifier *answer = boolNode->
getAnswer();
397 Region *region = currentBlock->
getParent();
// Extra operands: the compared value for equality questions, or all
// constraint arguments for constraint questions.
401 SmallVector<Value> args;
402 if (
auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404 }
else if (
auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(getValueAt(currentBlock, position));
411 Block *failure = failureBlockStack.back();
418 pdl_interp::IsNotNullOp::create(builder, loc, val,
success, failure);
421 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422 pdl_interp::CheckOperationNameOp::create(
423 builder, loc, val, opNameAnswer->getValue().getStringRef(),
success,
// Range-typed values are checked against a type array, single values against
// a single type.
428 auto *ans = cast<TypeAnswer>(answer);
429 if (isa<pdl::RangeType>(val.
getType()))
430 pdl_interp::CheckTypesOp::create(builder, loc, val,
431 llvm::cast<ArrayAttr>(ans->getValue()),
434 pdl_interp::CheckTypeOp::create(builder, loc, val,
435 llvm::cast<TypeAttr>(ans->getValue()),
440 auto *ans = cast<AttributeAnswer>(answer);
441 pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
447 pdl_interp::CheckOperandCountOp::create(
448 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
454 pdl_interp::CheckResultCountOp::create(
455 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
// Equality: a TrueAnswer takes `success` when the values are equal; any other
// answer swaps the two destinations.
460 bool trueAnswer = isa<TrueAnswer>(answer);
461 pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
462 trueAnswer ?
success : failure,
463 trueAnswer ? failure :
success);
// Constraint: the created op is recorded so that ConstraintPosition values can
// later read its results (see getValueAt).
467 auto *cstQuestion = cast<ConstraintQuestion>(question);
468 auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
469 builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
470 args, cstQuestion->getIsNegated(),
success, failure);
472 constraintOpMap.insert({cstQuestion, applyConstraintOp});
476 llvm_unreachable(
"Generating unknown Predicate operation");
/// Build a switch op of type `OpT` on `val`.
/// Each entry of `dests` contributes one case: the qualifier (cast to `PredT`)
/// supplies the case value via `getValue()`, and the mapped block is that
/// case's destination. `defaultDest` is passed as the switch's default
/// destination.
/// NOTE(review): the parameter line of the signature is missing from this
/// extraction. The full signature is
/// `createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
///                 llvm::MapVector<Qualifier *, Block *> &dests)`.
484template <
typename OpT,
typename PredT,
typename ValT =
typename PredT::KeyTy>
486 llvm::MapVector<Qualifier *, Block *> &dests) {
// Parallel vectors of case values and case destinations, in `dests` order.
487 std::vector<ValT> values;
488 std::vector<Block *> blocks;
489 values.reserve(dests.size());
490 blocks.reserve(dests.size());
491 for (
const auto &it : dests) {
492 blocks.push_back(it.second);
493 values.push_back(cast<PredT>(it.first)->getValue());
495 OpT::create(builder, val.
getLoc(), val, values, defaultDest, blocks);
/// Emit the switch for `switchNode` on value `val`. The block on top of
/// `failureBlockStack` serves as the default (failure) destination.
///
/// - Operand/result "at least" count questions are lowered to one count check
///   per child, visited in decreasing order of the children's UnsignedAnswer
///   values.
/// - All other questions lower to a single pdl_interp switch op through
///   createSwitchOp.
///
/// NOTE(review): the `val` parameter line and several branch/guard lines are
/// missing from this extraction.
498void PatternLowering::generate(SwitchNode *switchNode,
Block *currentBlock,
501 Region *region = currentBlock->
getParent();
502 Block *defaultDest = failureBlockStack.back();
// Child indices sorted by decreasing UnsignedAnswer value.
510 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
511 llvm::seq<unsigned>(0, switchNode->
getChildren().size()));
512 llvm::sort(sortedChildren, [&](
unsigned lhs,
unsigned rhs) {
513 return cast<UnsignedAnswer>(switchNode->
getChild(
lhs).first)->getValue() >
514 cast<UnsignedAnswer>(switchNode->
getChild(
rhs).first)->getValue();
// Build the chain of count checks. Each check branches to the child's matcher
// when it holds, and to `defaultDest` otherwise. The failure block on top of
// the stack is replaced by the most recent predicate block as the chain grows.
534 failureBlockStack.push_back(defaultDest);
535 Location loc = val.
getLoc();
536 for (
unsigned idx : sortedChildren) {
537 auto &child = switchNode->
getChild(idx);
538 Block *childBlock = generateMatcher(*child.second, *region);
541 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
544 pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
546 childBlock, defaultDest);
549 pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
551 childBlock, defaultDest);
554 llvm_unreachable(
"Generating invalid AtLeast operation");
556 failureBlockStack.back() = predicateBlock;
558 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
561 firstPredicateBlock->
erase();
// Generic case: generate each child's matcher, then emit one switch op whose
// flavor depends on the question kind (and, for types, on whether `val` is a
// range).
567 llvm::MapVector<Qualifier *, Block *> children;
569 children.insert({it.first, generateMatcher(*it.second, *region)});
575 int32_t>(val, defaultDest, builder, children);
578 int32_t>(val, defaultDest, builder, children);
584 if (isa<pdl::RangeType>(val.getType())) {
586 val, defaultDest, builder, children);
589 val, defaultDest, builder, children);
592 val, defaultDest, builder, children);
594 llvm_unreachable(
"Generating unknown switch predicate.");
/// Emit a `pdl_interp.record_match` for the pattern matched at `successNode`.
/// Steps:
/// 1. Generate the pattern's rewriter function.
/// 2. Materialize the matcher values that the rewriter consumes.
/// 3. Collect the names of the operations the rewriter creates and the root
///    operation kind (when known).
/// 4. Append the record op at the end of `currentBlock`, using the top of
///    `failureBlockStack` as the record op's continuation block.
/// NOTE(review): the declarations of `locOps` / `generatedOpsAttr`, and the
/// guard around the configMap update, are missing from this extraction.
598void PatternLowering::generate(SuccessNode *successNode,
Block *¤tBlock) {
599 pdl::PatternOp pattern = successNode->
getPattern();
600 Value root = successNode->
getRoot();
// The rewriter's arguments correspond, in order, to `usedMatchValues`.
604 SmallVector<Position *, 8> usedMatchValues;
605 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
608 std::vector<Value> mappedMatchValues;
609 mappedMatchValues.reserve(usedMatchValues.size());
610 for (Position *position : usedMatchValues)
611 mappedMatchValues.push_back(getValueAt(currentBlock, position));
// Names of all pdl.operation ops in the rewriter body (the generated ops).
614 SmallVector<StringRef, 4> generatedOps;
616 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
617 generatedOps.push_back(*op.getOpName());
619 if (!generatedOps.empty())
620 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
// Root kind is only set when the root is a pdl.operation with a known name.
623 StringAttr rootKindAttr;
624 if (pdl::OperationOp rootOp = root.
getDefiningOp<pdl::OperationOp>())
625 if (std::optional<StringRef> rootKind = rootOp.getOpName())
626 rootKindAttr = builder.getStringAttr(*rootKind);
628 builder.setInsertionPointToEnd(currentBlock);
629 auto matchOp = pdl_interp::RecordMatchOp::create(
630 builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
631 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
632 failureBlockStack.back());
// Carry the pattern's config over to the record op.
636 configMap->try_emplace(matchOp, configMap->lookup(pattern));
/// Create a new `pdl_interp.func` named "pdl_generated_rewriter" (renamed by
/// the symbol table on collision) in the rewriter module for `pattern`.
/// Returns a symbol reference to it, nested under the rewriter module name.
/// Matcher values needed by the rewriter become block arguments of the
/// function; their positions are appended to `usedMatchValues` in argument
/// order.
639SymbolRefAttr PatternLowering::generateRewriter(
640 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
641 builder.setInsertionPointToEnd(rewriterModule.getBody());
642 auto rewriterFunc = pdl_interp::FuncOp::create(
643 builder, pattern.getLoc(),
"pdl_generated_rewriter",
644 builder.getFunctionType({}, {}));
645 rewriterSymbolTable.
insert(rewriterFunc);
648 builder.setInsertionPointToEnd(&rewriterFunc.front());
// Map a PDL value to its rewriter-side value (memoized in rewriteValues):
// - constant attributes/types are rematerialized inside the rewriter;
// - any other value must be a pattern input and becomes a new argument of the
//   rewriter function.
651 DenseMap<Value, Value> rewriteValues;
652 auto mapRewriteValue = [&](Value oldValue) {
653 Value &newValue = rewriteValues[oldValue];
659 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
660 if (Attribute value = attrOp.getValueAttr()) {
661 return newValue = pdl_interp::CreateAttributeOp::create(
662 builder, attrOp.getLoc(), value);
664 }
else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
665 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
666 return newValue = pdl_interp::CreateTypeOp::create(
667 builder, typeOp.getLoc(), type);
669 }
else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
670 if (
ArrayAttr type = typeOp.getConstantTypesAttr()) {
671 return newValue = pdl_interp::CreateTypesOp::create(
672 builder, typeOp.getLoc(), typeOp.getType(), type);
// Otherwise: record the input's matcher position and add a function argument.
677 Position *inputPos = valueToPosition.lookup(oldValue);
678 assert(inputPos &&
"expected value to be a pattern input");
679 usedMatchValues.push_back(inputPos);
680 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
// A named rewrite lowers to a single pdl_interp.apply_rewrite call on the
// mapped root (if any) followed by the mapped external arguments.
686 pdl::RewriteOp rewriter = pattern.getRewriter();
687 if (StringAttr rewriteName = rewriter.getNameAttr()) {
688 SmallVector<Value> args;
689 if (rewriter.getRoot())
690 args.push_back(mapRewriteValue(rewriter.getRoot()));
692 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
693 args.append(mappedArgs.begin(), mappedArgs.end());
694 pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
// Lower each op of the rewrite body through the matching generateRewriter
// overload. The control flow around this loop is partly elided in this
// extraction.
699 for (Operation &rewriteOp : *rewriter.getBody()) {
700 llvm::TypeSwitch<Operation *>(&rewriteOp)
701 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
702 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
703 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](
auto op) {
704 this->generateRewriter(op, rewriteValues, mapRewriteValue);
// All match values have now been added as block arguments; update the
// function type to reflect them.
710 rewriterFunc.setType(builder.getFunctionType(
711 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
714 pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
715 return SymbolRefAttr::get(
716 builder.getContext(),
717 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
718 SymbolRefAttr::get(rewriterFunc));
/// Lower `pdl.apply_native_rewrite` to `pdl_interp.apply_rewrite`.
/// The arguments are mapped with `mapRewriteValue`, and each PDL result is
/// bound to the corresponding result of the new op.
721void PatternLowering::generateRewriter(
722 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
724 SmallVector<Value, 2> arguments;
725 for (Value argument : rewriteOp.getArgs())
726 arguments.push_back(mapRewriteValue(argument));
727 auto interpOp = pdl_interp::ApplyRewriteOp::create(
728 builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
729 rewriteOp.getNameAttr(), arguments);
730 for (
auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
731 rewriteValues[std::get<0>(it)] = std::get<1>(it);
/// Lower `pdl.attribute` to `pdl_interp.create_attribute` with the op's value
/// attribute, and bind the result as the rewriter value of `attrOp`.
734void PatternLowering::generateRewriter(
735 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
737 Value newAttr = pdl_interp::CreateAttributeOp::create(
738 builder, attrOp.getLoc(), attrOp.getValueAttr());
739 rewriteValues[attrOp] = newAttr;
/// Lower `pdl.erase` to `pdl_interp.erase` on the mapped operation value.
742void PatternLowering::generateRewriter(
743 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
745 pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
746 mapRewriteValue(eraseOp.getOpValue()));
/// Lower `pdl.operation` to `pdl_interp.create_operation`.
/// Operands and attributes are mapped; result types come from
/// generateOperationResultTypeRewriter. The created operation is bound as the
/// rewriter value of the PDL op. Afterwards, the PDL result type values are
/// bound to types read back from the created operation's results.
/// NOTE(review): some guard/continue lines in the result-type binding are
/// missing from this extraction.
749void PatternLowering::generateRewriter(
750 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
752 SmallVector<Value, 4> operands;
753 for (Value operand : operationOp.getOperandValues())
754 operands.push_back(mapRewriteValue(operand));
756 SmallVector<Value, 4> attributes;
757 for (Value attr : operationOp.getAttributeValues())
758 attributes.push_back(mapRewriteValue(attr));
760 bool hasInferredResultTypes =
false;
761 SmallVector<Value, 2> types;
762 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
763 rewriteValues, hasInferredResultTypes);
766 Location loc = operationOp.getLoc();
767 Value createdOp = pdl_interp::CreateOperationOp::create(
768 builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
769 operands, attributes, operationOp.getAttributeValueNames());
770 rewriteValues[operationOp.getOp()] = createdOp;
// A single range-typed result type value covers all results: bind it to the
// types of all results of the created op.
775 OperandRange resultTys = operationOp.getTypeValues();
776 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
777 Value &type = rewriteValues[resultTys[0]];
779 auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
780 type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
// Otherwise bind each result type value individually. After the first
// variadic (range-typed) entry, results are fetched by result group
// (GetResultsOp) rather than by index (GetResultOp).
786 bool seenVariableLength =
false;
787 Type valueTy = builder.
getType<pdl::ValueType>();
788 Type valueRangeTy = pdl::RangeType::get(valueTy);
789 for (
const auto &it : llvm::enumerate(resultTys)) {
790 Value &type = rewriteValues[it.value()];
793 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
794 seenVariableLength |= isVariadic;
799 if (seenVariableLength)
800 resultVal = pdl_interp::GetResultsOp::create(
801 builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
804 resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
805 createdOp, it.index());
806 type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
/// Lower `pdl.range` to `pdl_interp.create_range` over the mapped arguments,
/// keeping the range's result type.
810void PatternLowering::generateRewriter(
811 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
813 SmallVector<Value, 4> replOperands;
814 for (Value operand : rangeOp.getArguments())
815 replOperands.push_back(mapRewriteValue(operand));
816 rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
817 builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
/// Lower `pdl.replace`.
/// Replacement values are either the results of the replacement operation or
/// the explicit replacement values. If no replacement values are collected,
/// the replaced op is erased (`pdl_interp.erase`); otherwise it is replaced
/// (`pdl_interp.replace`).
820void PatternLowering::generateRewriter(
821 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
823 SmallVector<Value, 4> replOperands;
828 if (Value replOp = replaceOp.getReplOperation()) {
// Results of the replacement op are only used when the replaced op is not a
// pdl.operation, or is one that declares result types.
830 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
831 if (!opOp || !opOp.getTypeValues().empty()) {
832 replOperands.push_back(pdl_interp::GetResultsOp::create(
833 builder, replOp.getLoc(), mapRewriteValue(replOp)));
836 for (Value operand : replaceOp.getReplValues())
837 replOperands.push_back(mapRewriteValue(operand));
841 if (replOperands.empty()) {
842 pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
843 mapRewriteValue(replaceOp.getOpValue()));
847 pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
848 mapRewriteValue(replaceOp.getOpValue()),
/// Lower `pdl.result` to `pdl_interp.get_result`, reading result `getIndex()`
/// of the mapped parent operation as a single value.
852void PatternLowering::generateRewriter(
853 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
855 rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
856 builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
857 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
/// Lower `pdl.results` to `pdl_interp.get_results` on the mapped parent
/// operation, keeping the PDL op's result type and index.
860void PatternLowering::generateRewriter(
861 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
863 rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
864 builder, resultOp.getLoc(), resultOp.getType(),
865 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
/// Lower `pdl.type`. Only a constant type is materialized here, via
/// `pdl_interp.create_type`; the non-constant case is not handled in the
/// visible lines.
868void PatternLowering::generateRewriter(
869 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
873 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
874 rewriteValues[typeOp] =
875 pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
/// Lower `pdl.types`. Only a constant type list is materialized here, via
/// `pdl_interp.create_types`; the non-constant case is not handled in the
/// visible lines.
879void PatternLowering::generateRewriter(
880 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
884 if (
ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
885 rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
886 builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
/// Populate `types` with the result types for the operation created from `op`.
/// The visible strategies, in order:
///  1. If `op` lists result type values and they all resolve, use them.
///     An existing rewrite value is preferred; otherwise the value is mapped
///     via `mapRewriteValue`.
///  2. If `op` supports type inference, set `hasInferredResultTypes`.
///  3. If `op` is the replacement in a `pdl.replace` whose replaced op is in
///     the same rewriter block, reuse the replaced op's result types (the rest
///     of that condition is elided here).
///  4. Otherwise, when `op` lists no result types, presumably return with none.
///     In any other case emit an error and abort.
/// NOTE(review): several lines (early returns, the failure path inside
/// tryResolveResultTypes) are missing from this extraction.
890void PatternLowering::generateOperationResultTypeRewriter(
891 pdl::OperationOp op,
function_ref<Value(Value)> mapRewriteValue,
892 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
893 bool &hasInferredResultTypes) {
894 Block *rewriterBlock = op->getBlock();
899 OperandRange resultTypeValues = op.getTypeValues();
900 auto tryResolveResultTypes = [&] {
901 types.reserve(resultTypeValues.size());
902 for (
const auto &it : llvm::enumerate(resultTypeValues)) {
903 Value resultType = it.value();
906 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
907 types.push_back(existingRewriteValue);
913 types.push_back(mapRewriteValue(resultType));
924 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
928 if (op.hasTypeInference()) {
929 hasInferredResultTypes =
true;
// Look for a pdl.replace that uses `op` as a replacement. Uses where `op` is
// operand 0 (the op being replaced) are skipped.
935 for (OpOperand &use : op.getOp().getUses()) {
938 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
939 if (!replOpUser || use.getOperandNumber() == 0)
945 Value replOpVal = replOpUser.getOpValue();
947 if (replacedOp->
getBlock() == rewriterBlock &&
951 Value replacedOpResults = pdl_interp::GetResultsOp::create(
952 builder, replacedOp->
getLoc(), mapRewriteValue(replOpVal));
953 types.push_back(pdl_interp::GetValueTypeOp::create(
954 builder, replacedOp->
getLoc(), replacedOpResults));
961 if (resultTypeValues.empty())
// No strategy produced result types: report and abort.
967 op->emitOpError() <<
"unable to infer result type for operation";
968 llvm_unreachable(
"unable to infer result type for operation");
/// Pass that converts the PDL patterns of a module into a pdl_interp matcher
/// function and a rewriter module, using the generated
/// ConvertPDLToPDLInterpPassBase.
976struct PDLToPDLInterpPass
977 :
public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
978 PDLToPDLInterpPass() =
default;
979 PDLToPDLInterpPass(
const PDLToPDLInterpPass &
rhs) =
default;
// Constructor taking a pattern config map (its signature line is missing in
// this extraction). It stores the map's address, so the map presumably must
// outlive the pass.
981 : configMap(&configMap) {}
982 void runOnOperation() final;
// Optional pattern-config map; stays null for a default-constructed pass.
985 DenseMap<Operation *, PDLPatternConfigSet *> *configMap =
nullptr;
/// Create the pdl_interp matcher function and the rewriter module for the
/// current module, lower the module's PDL patterns into them, then walk the
/// PDL pattern ops to drop their config entries.
/// NOTE(review): several lines (builder setup, the `lower` call, and the
/// configMap guard) are missing from this extraction.
991void PDLToPDLInterpPass::runOnOperation() {
992 ModuleOp module = getOperation();
997 auto matcherFunc = pdl_interp::FuncOp::create(
998 builder, module.getLoc(),
999 pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
1002 ArrayRef<NamedAttribute>());
1006 ModuleOp rewriterModule =
1007 ModuleOp::create(builder, module.getLoc(),
1008 pdl_interp::PDLInterpDialect::getRewriterModuleName());
1011 PatternLowering
generator(matcherFunc, rewriterModule, configMap);
// make_early_inc_range keeps iteration valid if a pattern op is removed inside
// the loop (the pattern erase itself is not visible here -- confirm).
1015 for (pdl::PatternOp pattern :
1016 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1019 configMap->erase(pattern);
1027 return std::make_unique<PDLToPDLInterpPass>(configMap);
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Predicates::Kind getKind() const
Returns the kind of this position.
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Kind
An enumeration of the kinds of predicates.
@ ResultCountAtLeastQuestion
@ OperationPos
Positions, ordered by decreasing priority.
@ OperandCountAtLeastQuestion
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::unique_ptr<::mlir::Pass > createConvertPDLToPDLInterpPass()
llvm::function_ref< Fn > function_ref
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.
ChildMapT & getChildren()