15#include "llvm/ADT/MapVector.h"
16#include "llvm/ADT/ScopedHashTable.h"
17#include "llvm/ADT/Sequence.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/TypeSwitch.h"
22#define GEN_PASS_DEF_CONVERTPDLTOPDLINTERPPASS
23#include "mlir/Conversion/Passes.h.inc"
36struct PatternLowering {
38 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
43 void lower(ModuleOp module);
46 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
47 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
51 Block *generateMatcher(MatcherNode &node, Region &region,
52 Block *block =
nullptr);
57 Value getValueAt(
Block *&currentBlock, Position *pos);
61 void generate(BoolNode *boolNode,
Block *&currentBlock, Value val);
66 void generate(SwitchNode *switchNode,
Block *currentBlock, Value val);
71 void generate(SuccessNode *successNode,
Block *&currentBlock);
75 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
76 SmallVectorImpl<Position *> &usedMatchValues);
79 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
80 DenseMap<Value, Value> &rewriteValues,
82 void generateRewriter(pdl::AttributeOp attrOp,
83 DenseMap<Value, Value> &rewriteValues,
85 void generateRewriter(pdl::EraseOp eraseOp,
86 DenseMap<Value, Value> &rewriteValues,
88 void generateRewriter(pdl::OperationOp operationOp,
89 DenseMap<Value, Value> &rewriteValues,
91 void generateRewriter(pdl::RangeOp rangeOp,
92 DenseMap<Value, Value> &rewriteValues,
94 void generateRewriter(pdl::ReplaceOp replaceOp,
95 DenseMap<Value, Value> &rewriteValues,
97 void generateRewriter(pdl::ResultOp resultOp,
98 DenseMap<Value, Value> &rewriteValues,
100 void generateRewriter(pdl::ResultsOp resultOp,
101 DenseMap<Value, Value> &rewriteValues,
103 void generateRewriter(pdl::TypeOp typeOp,
104 DenseMap<Value, Value> &rewriteValues,
106 void generateRewriter(pdl::TypesOp typeOp,
107 DenseMap<Value, Value> &rewriteValues,
113 void generateOperationResultTypeRewriter(
114 pdl::OperationOp op,
function_ref<Value(Value)> mapRewriteValue,
115 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
116 bool &hasInferredResultTypes);
122 pdl_interp::FuncOp matcherFunc;
126 ModuleOp rewriterModule;
129 SymbolTable rewriterSymbolTable;
137 SmallVector<Block *, 8> failureBlockStack;
157PatternLowering::PatternLowering(
158 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
160 : builder(matcherFunc.
getContext()), matcherFunc(matcherFunc),
161 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
162 configMap(configMap) {}
164void PatternLowering::lower(ModuleOp module) {
165 PredicateUniquer predicateUniquer;
166 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
169 ValueMapScope topLevelValueScope(values);
173 Block *matcherEntryBlock = &matcherFunc.front();
174 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->
getArgument(0));
178 module, predicateBuilder, valueToPosition);
179 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
180 assert(failureBlockStack.empty() &&
"failed to empty the stack");
185 firstMatcherBlock->
erase();
188Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
193 ValueMapScope scope(values);
197 if (isa<ExitNode>(node)) {
199 pdl_interp::FinalizeOp::create(builder, matcherFunc.getLoc());
209 std::unique_ptr<MatcherNode> &failureNode = node.
getFailureNode();
212 failureBlock = generateMatcher(*failureNode, region);
213 failureBlockStack.push_back(failureBlock);
215 assert(!failureBlockStack.empty() &&
"expected valid failure block");
216 failureBlock = failureBlockStack.back();
221 Block *currentBlock = block;
223 Value val = position ? getValueAt(currentBlock, position) : Value();
227 bool isOperationValue = val && isa<pdl::OperationType>(val.
getType());
228 if (isOperationValue)
233 .Case<BoolNode, SwitchNode>([&](
auto *derivedNode) {
234 this->generate(derivedNode, currentBlock, val);
236 .Case([&](SuccessNode *successNode) {
237 generate(successNode, currentBlock);
242 while (failureBlockStack.back() != failureBlock) {
243 failureBlockStack.pop_back();
244 assert(!failureBlockStack.empty() &&
"unable to locate failure block");
249 failureBlockStack.pop_back();
251 if (isOperationValue)
257Value PatternLowering::getValueAt(
Block *&currentBlock, Position *pos) {
258 if (Value val = values.lookup(pos))
264 parentVal = getValueAt(currentBlock, parent);
272 auto *operationPos = cast<OperationPosition>(pos);
273 if (operationPos->isOperandDefiningOp())
275 value = pdl_interp::GetDefiningOpOp::create(
276 builder, loc, builder.
getType<pdl::OperationType>(), parentVal);
283 auto *usersPos = cast<UsersPosition>(pos);
288 if (isa<pdl::RangeType>(parentVal.
getType()) &&
289 usersPos->useRepresentative())
290 value = pdl_interp::ExtractOp::create(builder, loc, parentVal, 0);
295 value = pdl_interp::GetUsersOp::create(builder, loc, value);
299 assert(!failureBlockStack.empty() &&
"expected valid failure block");
300 auto foreach = pdl_interp::ForEachOp::create(
301 builder, loc, parentVal, failureBlockStack.back(),
true);
302 value =
foreach.getLoopVariable();
306 pdl_interp::ContinueOp::create(builder, loc);
307 failureBlockStack.push_back(continueBlock);
309 currentBlock = &
foreach.getRegion().
front();
313 auto *operandPos = cast<OperandPosition>(pos);
314 value = pdl_interp::GetOperandOp::create(
315 builder, loc, builder.
getType<pdl::ValueType>(), parentVal,
316 operandPos->getOperandNumber());
320 auto *operandPos = cast<OperandGroupPosition>(pos);
321 Type valueTy = builder.
getType<pdl::ValueType>();
322 value = pdl_interp::GetOperandsOp::create(
324 operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
325 parentVal, operandPos->getOperandGroupNumber());
329 auto *attrPos = cast<AttributePosition>(pos);
330 value = pdl_interp::GetAttributeOp::create(
331 builder, loc, builder.
getType<pdl::AttributeType>(), parentVal,
332 attrPos->getName().strref());
336 if (isa<pdl::AttributeType>(parentVal.
getType()))
337 value = pdl_interp::GetAttributeTypeOp::create(builder, loc, parentVal);
339 value = pdl_interp::GetValueTypeOp::create(builder, loc, parentVal);
343 auto *resPos = cast<ResultPosition>(pos);
344 value = pdl_interp::GetResultOp::create(
345 builder, loc, builder.
getType<pdl::ValueType>(), parentVal,
346 resPos->getResultNumber());
350 auto *resPos = cast<ResultGroupPosition>(pos);
351 Type valueTy = builder.
getType<pdl::ValueType>();
352 value = pdl_interp::GetResultsOp::create(
354 resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
355 parentVal, resPos->getResultGroupNumber());
359 auto *attrPos = cast<AttributeLiteralPosition>(pos);
360 value = pdl_interp::CreateAttributeOp::create(builder, loc,
361 attrPos->getValue());
365 auto *typePos = cast<TypeLiteralPosition>(pos);
366 Attribute rawTypeAttr = typePos->getValue();
367 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
368 value = pdl_interp::CreateTypeOp::create(builder, loc, typeAttr);
370 value = pdl_interp::CreateTypesOp::create(builder, loc,
371 cast<ArrayAttr>(rawTypeAttr));
377 auto *constrResPos = cast<ConstraintPosition>(pos);
378 auto i = constraintOpMap.find(constrResPos->getQuestion());
379 assert(i != constraintOpMap.end());
380 value = i->second->getResult(constrResPos->getIndex());
384 llvm_unreachable(
"Generating unknown Position getter");
388 values.insert(pos, value);
392void PatternLowering::generate(BoolNode *boolNode,
Block *&currentBlock,
394 Location loc = val.
getLoc();
396 Qualifier *answer = boolNode->
getAnswer();
397 Region *region = currentBlock->
getParent();
401 SmallVector<Value> args;
402 if (
auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
403 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
404 }
else if (
auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
405 for (Position *position : cstQuestion->getArgs())
406 args.push_back(getValueAt(currentBlock, position));
411 Block *failure = failureBlockStack.back();
418 pdl_interp::IsNotNullOp::create(builder, loc, val,
success, failure);
421 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
422 pdl_interp::CheckOperationNameOp::create(
423 builder, loc, val, opNameAnswer->getValue().getStringRef(),
success,
428 auto *ans = cast<TypeAnswer>(answer);
429 if (isa<pdl::RangeType>(val.
getType()))
430 pdl_interp::CheckTypesOp::create(builder, loc, val,
431 llvm::cast<ArrayAttr>(ans->getValue()),
434 pdl_interp::CheckTypeOp::create(builder, loc, val,
435 llvm::cast<TypeAttr>(ans->getValue()),
440 auto *ans = cast<AttributeAnswer>(answer);
441 pdl_interp::CheckAttributeOp::create(builder, loc, val, ans->getValue(),
447 pdl_interp::CheckOperandCountOp::create(
448 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
454 pdl_interp::CheckResultCountOp::create(
455 builder, loc, val, cast<UnsignedAnswer>(answer)->getValue(),
460 bool trueAnswer = isa<TrueAnswer>(answer);
461 pdl_interp::AreEqualOp::create(builder, loc, val, args.front(),
462 trueAnswer ?
success : failure,
463 trueAnswer ? failure :
success);
467 auto *cstQuestion = cast<ConstraintQuestion>(question);
468 auto applyConstraintOp = pdl_interp::ApplyConstraintOp::create(
469 builder, loc, cstQuestion->getResultTypes(), cstQuestion->getName(),
470 args, cstQuestion->getIsNegated(),
success, failure);
472 constraintOpMap.insert({cstQuestion, applyConstraintOp});
476 llvm_unreachable(
"Generating unknown Predicate operation");
484template <
typename OpT,
typename PredT,
typename ValT =
typename PredT::KeyTy>
486 llvm::MapVector<Qualifier *, Block *> &dests) {
487 std::vector<ValT> values;
488 std::vector<Block *> blocks;
489 values.reserve(dests.size());
490 blocks.reserve(dests.size());
491 for (
const auto &it : dests) {
492 blocks.push_back(it.second);
493 values.push_back(cast<PredT>(it.first)->getValue());
495 OpT::create(builder, val.
getLoc(), val, values, defaultDest, blocks);
498void PatternLowering::generate(SwitchNode *switchNode,
Block *currentBlock,
501 Region *region = currentBlock->
getParent();
502 Block *defaultDest = failureBlockStack.back();
510 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
511 llvm::seq<unsigned>(0, switchNode->
getChildren().size()));
512 llvm::sort(sortedChildren, [&](
unsigned lhs,
unsigned rhs) {
513 return cast<UnsignedAnswer>(switchNode->
getChild(
lhs).first)->getValue() >
514 cast<UnsignedAnswer>(switchNode->
getChild(
rhs).first)->getValue();
534 failureBlockStack.push_back(defaultDest);
535 Location loc = val.
getLoc();
536 for (
unsigned idx : sortedChildren) {
537 auto &child = switchNode->
getChild(idx);
538 Block *childBlock = generateMatcher(*child.second, *region);
541 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
544 pdl_interp::CheckOperandCountOp::create(builder, loc, val, ans,
546 childBlock, defaultDest);
549 pdl_interp::CheckResultCountOp::create(builder, loc, val, ans,
551 childBlock, defaultDest);
554 llvm_unreachable(
"Generating invalid AtLeast operation");
556 failureBlockStack.back() = predicateBlock;
558 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
561 firstPredicateBlock->
erase();
567 llvm::MapVector<Qualifier *, Block *> children;
569 children.insert({it.first, generateMatcher(*it.second, *region)});
575 int32_t>(val, defaultDest, builder, children);
578 int32_t>(val, defaultDest, builder, children);
584 if (isa<pdl::RangeType>(val.getType())) {
586 val, defaultDest, builder, children);
589 val, defaultDest, builder, children);
592 val, defaultDest, builder, children);
594 llvm_unreachable(
"Generating unknown switch predicate.");
598void PatternLowering::generate(SuccessNode *successNode,
Block *&currentBlock) {
599 pdl::PatternOp pattern = successNode->
getPattern();
600 Value root = successNode->
getRoot();
604 SmallVector<Position *, 8> usedMatchValues;
605 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
608 std::vector<Value> mappedMatchValues;
609 mappedMatchValues.reserve(usedMatchValues.size());
610 for (Position *position : usedMatchValues)
611 mappedMatchValues.push_back(getValueAt(currentBlock, position));
614 SmallVector<StringRef, 4> generatedOps;
616 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
617 generatedOps.push_back(*op.getOpName());
619 if (!generatedOps.empty())
620 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
623 StringAttr rootKindAttr;
624 if (pdl::OperationOp rootOp = root.
getDefiningOp<pdl::OperationOp>())
625 if (std::optional<StringRef> rootKind = rootOp.getOpName())
626 rootKindAttr = builder.getStringAttr(*rootKind);
628 builder.setInsertionPointToEnd(currentBlock);
629 auto matchOp = pdl_interp::RecordMatchOp::create(
630 builder, pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
631 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
632 failureBlockStack.back());
636 configMap->try_emplace(matchOp, configMap->lookup(pattern));
639SymbolRefAttr PatternLowering::generateRewriter(
640 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
641 builder.setInsertionPointToEnd(rewriterModule.getBody());
643 StringRef rewriterName =
"pdl_generated_rewriter";
644 if (
auto symName = pattern.getSymName())
645 rewriterName = symName.value();
646 auto rewriterFunc = pdl_interp::FuncOp::create(
647 builder, pattern.getLoc(), rewriterName, builder.getFunctionType({}, {}));
648 rewriterSymbolTable.
insert(rewriterFunc);
651 builder.setInsertionPointToEnd(&rewriterFunc.front());
654 DenseMap<Value, Value> rewriteValues;
655 auto mapRewriteValue = [&](Value oldValue) {
656 Value &newValue = rewriteValues[oldValue];
662 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
663 if (Attribute value = attrOp.getValueAttr()) {
664 return newValue = pdl_interp::CreateAttributeOp::create(
665 builder, attrOp.getLoc(), value);
667 }
else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
668 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
669 return newValue = pdl_interp::CreateTypeOp::create(
670 builder, typeOp.getLoc(), type);
672 }
else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
673 if (
ArrayAttr type = typeOp.getConstantTypesAttr()) {
674 return newValue = pdl_interp::CreateTypesOp::create(
675 builder, typeOp.getLoc(), typeOp.getType(), type);
680 Position *inputPos = valueToPosition.lookup(oldValue);
681 assert(inputPos &&
"expected value to be a pattern input");
682 usedMatchValues.push_back(inputPos);
683 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
689 pdl::RewriteOp rewriter = pattern.getRewriter();
690 if (StringAttr rewriteName = rewriter.getNameAttr()) {
691 SmallVector<Value> args;
692 if (rewriter.getRoot())
693 args.push_back(mapRewriteValue(rewriter.getRoot()));
695 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
696 args.append(mappedArgs.begin(), mappedArgs.end());
697 pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
702 for (Operation &rewriteOp : *rewriter.getBody()) {
703 llvm::TypeSwitch<Operation *>(&rewriteOp)
704 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
705 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
706 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](
auto op) {
707 this->generateRewriter(op, rewriteValues, mapRewriteValue);
713 rewriterFunc.setType(builder.getFunctionType(
714 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
717 pdl_interp::FinalizeOp::create(builder, rewriter.getLoc());
718 return SymbolRefAttr::get(
719 builder.getContext(),
720 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
721 SymbolRefAttr::get(rewriterFunc));
724void PatternLowering::generateRewriter(
725 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
727 SmallVector<Value, 2> arguments;
728 for (Value argument : rewriteOp.getArgs())
729 arguments.push_back(mapRewriteValue(argument));
730 auto interpOp = pdl_interp::ApplyRewriteOp::create(
731 builder, rewriteOp.getLoc(), rewriteOp.getResultTypes(),
732 rewriteOp.getNameAttr(), arguments);
733 for (
auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
734 rewriteValues[std::get<0>(it)] = std::get<1>(it);
737void PatternLowering::generateRewriter(
738 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
740 Value newAttr = pdl_interp::CreateAttributeOp::create(
741 builder, attrOp.getLoc(), attrOp.getValueAttr());
742 rewriteValues[attrOp] = newAttr;
745void PatternLowering::generateRewriter(
746 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
748 pdl_interp::EraseOp::create(builder, eraseOp.getLoc(),
749 mapRewriteValue(eraseOp.getOpValue()));
752void PatternLowering::generateRewriter(
753 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
755 SmallVector<Value, 4> operands;
756 for (Value operand : operationOp.getOperandValues())
757 operands.push_back(mapRewriteValue(operand));
759 SmallVector<Value, 4> attributes;
760 for (Value attr : operationOp.getAttributeValues())
761 attributes.push_back(mapRewriteValue(attr));
763 bool hasInferredResultTypes =
false;
764 SmallVector<Value, 2> types;
765 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
766 rewriteValues, hasInferredResultTypes);
769 Location loc = operationOp.getLoc();
770 Value createdOp = pdl_interp::CreateOperationOp::create(
771 builder, loc, *operationOp.getOpName(), types, hasInferredResultTypes,
772 operands, attributes, operationOp.getAttributeValueNames());
773 rewriteValues[operationOp.getOp()] = createdOp;
778 OperandRange resultTys = operationOp.getTypeValues();
779 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
780 Value &type = rewriteValues[resultTys[0]];
782 auto results = pdl_interp::GetResultsOp::create(builder, loc, createdOp);
783 type = pdl_interp::GetValueTypeOp::create(builder, loc, results);
789 bool seenVariableLength =
false;
790 Type valueTy = builder.
getType<pdl::ValueType>();
791 Type valueRangeTy = pdl::RangeType::get(valueTy);
792 for (
const auto &it : llvm::enumerate(resultTys)) {
793 Value &type = rewriteValues[it.value()];
796 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
797 seenVariableLength |= isVariadic;
802 if (seenVariableLength)
803 resultVal = pdl_interp::GetResultsOp::create(
804 builder, loc, isVariadic ? valueRangeTy : valueTy, createdOp,
807 resultVal = pdl_interp::GetResultOp::create(builder, loc, valueTy,
808 createdOp, it.index());
809 type = pdl_interp::GetValueTypeOp::create(builder, loc, resultVal);
813void PatternLowering::generateRewriter(
814 pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues,
816 SmallVector<Value, 4> replOperands;
817 for (Value operand : rangeOp.getArguments())
818 replOperands.push_back(mapRewriteValue(operand));
819 rewriteValues[rangeOp] = pdl_interp::CreateRangeOp::create(
820 builder, rangeOp.getLoc(), rangeOp.getType(), replOperands);
823void PatternLowering::generateRewriter(
824 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
826 SmallVector<Value, 4> replOperands;
831 if (Value replOp = replaceOp.getReplOperation()) {
833 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
834 if (!opOp || !opOp.getTypeValues().empty()) {
835 replOperands.push_back(pdl_interp::GetResultsOp::create(
836 builder, replOp.getLoc(), mapRewriteValue(replOp)));
839 for (Value operand : replaceOp.getReplValues())
840 replOperands.push_back(mapRewriteValue(operand));
844 if (replOperands.empty()) {
845 pdl_interp::EraseOp::create(builder, replaceOp.getLoc(),
846 mapRewriteValue(replaceOp.getOpValue()));
850 pdl_interp::ReplaceOp::create(builder, replaceOp.getLoc(),
851 mapRewriteValue(replaceOp.getOpValue()),
855void PatternLowering::generateRewriter(
856 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
858 rewriteValues[resultOp] = pdl_interp::GetResultOp::create(
859 builder, resultOp.getLoc(), builder.getType<pdl::ValueType>(),
860 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
863void PatternLowering::generateRewriter(
864 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
866 rewriteValues[resultOp] = pdl_interp::GetResultsOp::create(
867 builder, resultOp.getLoc(), resultOp.getType(),
868 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
871void PatternLowering::generateRewriter(
872 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
876 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
877 rewriteValues[typeOp] =
878 pdl_interp::CreateTypeOp::create(builder, typeOp.getLoc(), typeAttr);
882void PatternLowering::generateRewriter(
883 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
887 if (
ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
888 rewriteValues[typeOp] = pdl_interp::CreateTypesOp::create(
889 builder, typeOp.getLoc(), typeOp.getType(), typeAttr);
893void PatternLowering::generateOperationResultTypeRewriter(
894 pdl::OperationOp op,
function_ref<Value(Value)> mapRewriteValue,
895 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
896 bool &hasInferredResultTypes) {
897 Block *rewriterBlock = op->getBlock();
902 OperandRange resultTypeValues = op.getTypeValues();
903 auto tryResolveResultTypes = [&] {
904 types.reserve(resultTypeValues.size());
905 for (
const auto &it : llvm::enumerate(resultTypeValues)) {
906 Value resultType = it.value();
909 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
910 types.push_back(existingRewriteValue);
916 types.push_back(mapRewriteValue(resultType));
927 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
931 if (op.hasTypeInference()) {
932 hasInferredResultTypes =
true;
938 for (OpOperand &use : op.getOp().getUses()) {
941 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
942 if (!replOpUser || use.getOperandNumber() == 0)
948 Value replOpVal = replOpUser.getOpValue();
950 if (replacedOp->
getBlock() == rewriterBlock &&
954 Value replacedOpResults = pdl_interp::GetResultsOp::create(
955 builder, replacedOp->
getLoc(), mapRewriteValue(replOpVal));
956 types.push_back(pdl_interp::GetValueTypeOp::create(
957 builder, replacedOp->
getLoc(), replacedOpResults));
964 if (resultTypeValues.empty())
970 op->emitOpError() <<
"unable to infer result type for operation";
971 llvm_unreachable(
"unable to infer result type for operation");
979struct PDLToPDLInterpPass
980 :
public impl::ConvertPDLToPDLInterpPassBase<PDLToPDLInterpPass> {
981 PDLToPDLInterpPass() =
default;
982 PDLToPDLInterpPass(
const PDLToPDLInterpPass &
rhs) =
default;
984 : configMap(&configMap) {}
985 void runOnOperation() final;
988 DenseMap<Operation *, PDLPatternConfigSet *> *configMap =
nullptr;
994void PDLToPDLInterpPass::runOnOperation() {
995 ModuleOp module = getOperation();
1000 auto matcherFunc = pdl_interp::FuncOp::create(
1001 builder, module.getLoc(),
1002 pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
1005 ArrayRef<NamedAttribute>());
1009 ModuleOp rewriterModule =
1010 ModuleOp::create(builder, module.getLoc(),
1011 pdl_interp::PDLInterpDialect::getRewriterModuleName());
1014 PatternLowering
generator(matcherFunc, rewriterModule, configMap);
1018 for (pdl::PatternOp pattern :
1019 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
1022 configMap->erase(pattern);
1030 return std::make_unique<PDLToPDLInterpPass>(configMap);
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Location getLoc()
The source location the operation was defined or derived from.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
static std::unique_ptr< MatcherNode > generateMatcherTree(ModuleOp module, PredicateBuilder &builder, DenseMap< Value, Position * > &valueToPosition)
Given a module containing PDL pattern operations, generate a matcher tree using the patterns within t...
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Predicates::Kind getKind() const
Returns the kind of this position.
Predicates::Kind getKind() const
Returns the kind of this qualifier.
Kind
An enumeration of the kinds of predicates.
@ ResultCountAtLeastQuestion
@ OperationPos
Positions, ordered by decreasing priority.
@ OperandCountAtLeastQuestion
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::unique_ptr<::mlir::Pass > createConvertPDLToPDLInterpPass()
llvm::function_ref< Fn > function_ref
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.
ChildMapT & getChildren()