16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
24 #define GEN_PASS_DEF_CONVERTPDLTOPDLINTERP
25 #include "mlir/Conversion/Passes.h.inc"
38 struct PatternLowering {
40 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
45 void lower(ModuleOp module);
48 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
49 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
76 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
80 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
83 void generateRewriter(pdl::AttributeOp attrOp,
86 void generateRewriter(pdl::EraseOp eraseOp,
89 void generateRewriter(pdl::OperationOp operationOp,
92 void generateRewriter(pdl::RangeOp rangeOp,
95 void generateRewriter(pdl::ReplaceOp replaceOp,
98 void generateRewriter(pdl::ResultOp resultOp,
101 void generateRewriter(pdl::ResultsOp resultOp,
104 void generateRewriter(pdl::TypeOp typeOp,
107 void generateRewriter(pdl::TypesOp typeOp,
114 void generateOperationResultTypeRewriter(
117 bool &hasInferredResultTypes);
123 pdl_interp::FuncOp matcherFunc;
127 ModuleOp rewriterModule;
154 PatternLowering::PatternLowering(
155 pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
157 : builder(matcherFunc.
getContext()), matcherFunc(matcherFunc),
158 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
159 configMap(configMap) {}
161 void PatternLowering::lower(ModuleOp module) {
166 ValueMapScope topLevelValueScope(values);
170 Block *matcherEntryBlock = &matcherFunc.
front();
171 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->
getArgument(0));
174 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
175 module, predicateBuilder, valueToPosition);
176 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
177 assert(failureBlockStack.empty() &&
"failed to empty the stack");
182 firstMatcherBlock->
erase();
188 ValueMapScope scope(values);
192 if (isa<ExitNode>(node)) {
193 builder.setInsertionPointToEnd(block);
194 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
204 std::unique_ptr<MatcherNode> &failureNode = node.
getFailureNode();
207 failureBlock = generateMatcher(*failureNode, region);
208 failureBlockStack.push_back(failureBlock);
210 assert(!failureBlockStack.empty() &&
"expected valid failure block");
211 failureBlock = failureBlockStack.
back();
216 Block *currentBlock = block;
218 Value val = position ? getValueAt(currentBlock, position) :
Value();
222 bool isOperationValue = val && isa<pdl::OperationType>(val.
getType());
223 if (isOperationValue)
229 this->generate(derivedNode, currentBlock, val);
232 generate(successNode, currentBlock);
237 while (failureBlockStack.back() != failureBlock) {
238 failureBlockStack.pop_back();
239 assert(!failureBlockStack.empty() &&
"unable to locate failure block");
244 failureBlockStack.pop_back();
246 if (isOperationValue)
253 if (
Value val = values.lookup(pos))
259 parentVal = getValueAt(currentBlock, parent);
262 Location loc = parentVal ? parentVal.
getLoc() : builder.getUnknownLoc();
263 builder.setInsertionPointToEnd(currentBlock);
267 auto *operationPos = cast<OperationPosition>(pos);
268 if (operationPos->isOperandDefiningOp())
270 value = builder.create<pdl_interp::GetDefiningOpOp>(
271 loc, builder.getType<pdl::OperationType>(), parentVal);
278 auto *usersPos = cast<UsersPosition>(pos);
283 if (isa<pdl::RangeType>(parentVal.
getType()) &&
284 usersPos->useRepresentative())
285 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
290 value = builder.create<pdl_interp::GetUsersOp>(loc, value);
294 assert(!failureBlockStack.empty() &&
"expected valid failure block");
295 auto foreach = builder.create<pdl_interp::ForEachOp>(
296 loc, parentVal, failureBlockStack.back(),
true);
297 value =
foreach.getLoopVariable();
300 Block *continueBlock = builder.createBlock(&
foreach.getRegion());
301 builder.create<pdl_interp::ContinueOp>(loc);
302 failureBlockStack.
push_back(continueBlock);
304 currentBlock = &
foreach.getRegion().
front();
308 auto *operandPos = cast<OperandPosition>(pos);
309 value = builder.create<pdl_interp::GetOperandOp>(
310 loc, builder.getType<pdl::ValueType>(), parentVal,
311 operandPos->getOperandNumber());
315 auto *operandPos = cast<OperandGroupPosition>(pos);
316 Type valueTy = builder.getType<pdl::ValueType>();
317 value = builder.create<pdl_interp::GetOperandsOp>(
319 parentVal, operandPos->getOperandGroupNumber());
323 auto *attrPos = cast<AttributePosition>(pos);
324 value = builder.create<pdl_interp::GetAttributeOp>(
325 loc, builder.getType<pdl::AttributeType>(), parentVal,
326 attrPos->getName().strref());
330 if (isa<pdl::AttributeType>(parentVal.getType()))
331 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
333 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
337 auto *resPos = cast<ResultPosition>(pos);
338 value = builder.create<pdl_interp::GetResultOp>(
339 loc, builder.getType<pdl::ValueType>(), parentVal,
340 resPos->getResultNumber());
344 auto *resPos = cast<ResultGroupPosition>(pos);
345 Type valueTy = builder.getType<pdl::ValueType>();
346 value = builder.create<pdl_interp::GetResultsOp>(
348 parentVal, resPos->getResultGroupNumber());
352 auto *attrPos = cast<AttributeLiteralPosition>(pos);
354 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
358 auto *typePos = cast<TypeLiteralPosition>(pos);
359 Attribute rawTypeAttr = typePos->getValue();
360 if (TypeAttr typeAttr = dyn_cast<TypeAttr>(rawTypeAttr))
361 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
363 value = builder.create<pdl_interp::CreateTypesOp>(
364 loc, cast<ArrayAttr>(rawTypeAttr));
368 llvm_unreachable(
"Generating unknown Position getter");
372 values.insert(pos, value);
376 void PatternLowering::generate(
BoolNode *boolNode,
Block *&currentBlock,
386 if (
auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
387 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
388 }
else if (
auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
389 for (
Position *position : cstQuestion->getArgs())
390 args.push_back(getValueAt(currentBlock, position));
399 builder.setInsertionPointToEnd(currentBlock);
403 builder.create<pdl_interp::IsNotNullOp>(loc, val,
success,
failure);
406 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
407 builder.create<pdl_interp::CheckOperationNameOp>(
408 loc, val, opNameAnswer->getValue().getStringRef(),
success,
failure);
412 auto *ans = cast<TypeAnswer>(answer);
413 if (isa<pdl::RangeType>(val.getType()))
414 builder.create<pdl_interp::CheckTypesOp>(
415 loc, val, llvm::cast<ArrayAttr>(ans->getValue()),
success,
failure);
417 builder.create<pdl_interp::CheckTypeOp>(
418 loc, val, llvm::cast<TypeAttr>(ans->getValue()),
success,
failure);
422 auto *ans = cast<AttributeAnswer>(answer);
423 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
429 builder.create<pdl_interp::CheckOperandCountOp>(
430 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
436 builder.create<pdl_interp::CheckResultCountOp>(
437 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
442 bool trueAnswer = isa<TrueAnswer>(answer);
443 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
449 auto *cstQuestion = cast<ConstraintQuestion>(question);
450 builder.create<pdl_interp::ApplyConstraintOp>(
451 loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(),
success,
456 llvm_unreachable(
"Generating unknown Predicate operation");
460 template <
typename OpT,
typename PredT,
typename ValT =
typename PredT::KeyTy>
462 llvm::MapVector<Qualifier *, Block *> &dests) {
463 std::vector<ValT> values;
464 std::vector<Block *> blocks;
465 values.reserve(dests.size());
466 blocks.reserve(dests.size());
467 for (
const auto &it : dests) {
468 blocks.push_back(it.second);
469 values.push_back(cast<PredT>(it.first)->getValue());
471 builder.
create<OpT>(val.
getLoc(), val, values, defaultDest, blocks);
474 void PatternLowering::generate(
SwitchNode *switchNode,
Block *currentBlock,
478 Block *defaultDest = failureBlockStack.
back();
487 llvm::seq<unsigned>(0, switchNode->
getChildren().size()));
488 llvm::sort(sortedChildren, [&](
unsigned lhs,
unsigned rhs) {
489 return cast<UnsignedAnswer>(switchNode->
getChild(lhs).first)->getValue() >
490 cast<UnsignedAnswer>(switchNode->
getChild(rhs).first)->getValue();
510 failureBlockStack.push_back(defaultDest);
512 for (
unsigned idx : sortedChildren) {
513 auto &child = switchNode->
getChild(idx);
514 Block *childBlock = generateMatcher(*child.second, *region);
515 Block *predicateBlock = builder.createBlock(childBlock);
516 builder.setInsertionPointToEnd(predicateBlock);
517 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
520 builder.create<pdl_interp::CheckOperandCountOp>(
521 loc, val, ans,
true, childBlock, defaultDest);
524 builder.create<pdl_interp::CheckResultCountOp>(
525 loc, val, ans,
true, childBlock, defaultDest);
528 llvm_unreachable(
"Generating invalid AtLeast operation");
530 failureBlockStack.back() = predicateBlock;
532 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
535 firstPredicateBlock->
erase();
541 llvm::MapVector<Qualifier *, Block *> children;
543 children.insert({it.first, generateMatcher(*it.second, *region)});
544 builder.setInsertionPointToEnd(currentBlock);
549 int32_t>(val, defaultDest, builder, children);
552 int32_t>(val, defaultDest, builder, children);
558 if (isa<pdl::RangeType>(val.getType())) {
559 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
560 val, defaultDest, builder, children);
562 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
563 val, defaultDest, builder, children);
565 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
566 val, defaultDest, builder, children);
568 llvm_unreachable(
"Generating unknown switch predicate.");
572 void PatternLowering::generate(
SuccessNode *successNode,
Block *&currentBlock) {
573 pdl::PatternOp pattern = successNode->
getPattern();
579 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
582 std::vector<Value> mappedMatchValues;
583 mappedMatchValues.reserve(usedMatchValues.size());
584 for (
Position *position : usedMatchValues)
585 mappedMatchValues.push_back(getValueAt(currentBlock, position));
590 pattern.getRewriter().getBodyRegion().getOps<pdl::OperationOp>())
591 generatedOps.push_back(*op.getOpName());
592 ArrayAttr generatedOpsAttr;
593 if (!generatedOps.empty())
594 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
597 StringAttr rootKindAttr;
598 if (pdl::OperationOp rootOp = root.
getDefiningOp<pdl::OperationOp>())
599 if (std::optional<StringRef> rootKind = rootOp.getOpName())
600 rootKindAttr = builder.getStringAttr(*rootKind);
602 builder.setInsertionPointToEnd(currentBlock);
603 auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
604 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
605 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
606 failureBlockStack.back());
610 configMap->try_emplace(matchOp, configMap->lookup(pattern));
613 SymbolRefAttr PatternLowering::generateRewriter(
615 builder.setInsertionPointToEnd(rewriterModule.getBody());
616 auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
617 pattern.getLoc(),
"pdl_generated_rewriter",
618 builder.getFunctionType(std::nullopt, std::nullopt));
619 rewriterSymbolTable.insert(rewriterFunc);
622 builder.setInsertionPointToEnd(&rewriterFunc.front());
626 auto mapRewriteValue = [&](
Value oldValue) {
627 Value &newValue = rewriteValues[oldValue];
632 Operation *oldOp = oldValue.getDefiningOp();
633 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
634 if (
Attribute value = attrOp.getValueAttr()) {
635 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
636 attrOp.getLoc(), value);
638 }
else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
639 if (TypeAttr type = typeOp.getConstantTypeAttr()) {
640 return newValue = builder.create<pdl_interp::CreateTypeOp>(
641 typeOp.getLoc(), type);
643 }
else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
644 if (ArrayAttr type = typeOp.getConstantTypesAttr()) {
645 return newValue = builder.create<pdl_interp::CreateTypesOp>(
646 typeOp.getLoc(), typeOp.getType(), type);
651 Position *inputPos = valueToPosition.lookup(oldValue);
652 assert(inputPos &&
"expected value to be a pattern input");
653 usedMatchValues.push_back(inputPos);
654 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
660 pdl::RewriteOp rewriter = pattern.getRewriter();
661 if (StringAttr rewriteName = rewriter.getNameAttr()) {
663 if (rewriter.getRoot())
664 args.push_back(mapRewriteValue(rewriter.getRoot()));
666 llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
667 args.append(mappedArgs.begin(), mappedArgs.end());
668 builder.create<pdl_interp::ApplyRewriteOp>(
669 rewriter.getLoc(),
TypeRange(), rewriteName, args);
672 for (
Operation &rewriteOp : *rewriter.getBody()) {
674 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
675 pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp,
676 pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](
auto op) {
677 this->generateRewriter(op, rewriteValues, mapRewriteValue);
683 rewriterFunc.setType(builder.getFunctionType(
684 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
687 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
689 builder.getContext(),
690 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
694 void PatternLowering::generateRewriter(
698 for (
Value argument : rewriteOp.getArgs())
699 arguments.push_back(mapRewriteValue(argument));
700 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
701 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.getNameAttr(),
703 for (
auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
704 rewriteValues[std::get<0>(it)] = std::get<1>(it);
707 void PatternLowering::generateRewriter(
710 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
711 attrOp.getLoc(), attrOp.getValueAttr());
712 rewriteValues[attrOp] = newAttr;
715 void PatternLowering::generateRewriter(
718 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
719 mapRewriteValue(eraseOp.getOpValue()));
722 void PatternLowering::generateRewriter(
726 for (
Value operand : operationOp.getOperandValues())
727 operands.push_back(mapRewriteValue(operand));
730 for (
Value attr : operationOp.getAttributeValues())
731 attributes.push_back(mapRewriteValue(attr));
733 bool hasInferredResultTypes =
false;
735 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
736 rewriteValues, hasInferredResultTypes);
739 Location loc = operationOp.getLoc();
740 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
741 loc, *operationOp.getOpName(), types, hasInferredResultTypes, operands,
742 attributes, operationOp.getAttributeValueNames());
743 rewriteValues[operationOp.getOp()] = createdOp;
749 if (resultTys.size() == 1 && isa<pdl::RangeType>(resultTys[0].getType())) {
750 Value &type = rewriteValues[resultTys[0]];
752 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
753 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
759 bool seenVariableLength =
false;
760 Type valueTy = builder.getType<pdl::ValueType>();
763 Value &type = rewriteValues[it.value()];
766 bool isVariadic = isa<pdl::RangeType>(it.value().getType());
767 seenVariableLength |= isVariadic;
772 if (seenVariableLength)
773 resultVal = builder.create<pdl_interp::GetResultsOp>(
774 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
776 resultVal = builder.create<pdl_interp::GetResultOp>(
777 loc, valueTy, createdOp, it.index());
778 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
782 void PatternLowering::generateRewriter(
786 for (
Value operand : rangeOp.getArguments())
787 replOperands.push_back(mapRewriteValue(operand));
788 rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>(
789 rangeOp.getLoc(), rangeOp.getType(), replOperands);
792 void PatternLowering::generateRewriter(
800 if (
Value replOp = replaceOp.getReplOperation()) {
802 auto opOp = replaceOp.getOpValue().getDefiningOp<pdl::OperationOp>();
803 if (!opOp || !opOp.getTypeValues().empty()) {
804 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
805 replOp.getLoc(), mapRewriteValue(replOp)));
808 for (
Value operand : replaceOp.getReplValues())
809 replOperands.push_back(mapRewriteValue(operand));
813 if (replOperands.empty()) {
814 builder.create<pdl_interp::EraseOp>(
815 replaceOp.getLoc(), mapRewriteValue(replaceOp.getOpValue()));
819 builder.create<pdl_interp::ReplaceOp>(replaceOp.getLoc(),
820 mapRewriteValue(replaceOp.getOpValue()),
824 void PatternLowering::generateRewriter(
827 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
828 resultOp.getLoc(), builder.getType<pdl::ValueType>(),
829 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
832 void PatternLowering::generateRewriter(
835 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
836 resultOp.getLoc(), resultOp.getType(),
837 mapRewriteValue(resultOp.getParent()), resultOp.getIndex());
840 void PatternLowering::generateRewriter(
845 if (TypeAttr typeAttr = typeOp.getConstantTypeAttr()) {
846 rewriteValues[typeOp] =
847 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
851 void PatternLowering::generateRewriter(
856 if (ArrayAttr typeAttr = typeOp.getConstantTypesAttr()) {
857 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
858 typeOp.getLoc(), typeOp.getType(), typeAttr);
862 void PatternLowering::generateOperationResultTypeRewriter(
865 bool &hasInferredResultTypes) {
872 auto tryResolveResultTypes = [&] {
873 types.reserve(resultTypeValues.size());
875 Value resultType = it.value();
878 if (
Value existingRewriteValue = rewriteValues.lookup(resultType)) {
879 types.push_back(existingRewriteValue);
885 types.push_back(mapRewriteValue(resultType));
896 if (!resultTypeValues.empty() &&
succeeded(tryResolveResultTypes()))
900 if (op.hasTypeInference()) {
901 hasInferredResultTypes =
true;
910 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
911 if (!replOpUser || use.getOperandNumber() == 0)
917 Value replOpVal = replOpUser.getOpValue();
919 if (replacedOp->
getBlock() == rewriterBlock &&
923 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
924 replacedOp->
getLoc(), mapRewriteValue(replOpVal));
925 types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
926 replacedOp->
getLoc(), replacedOpResults));
933 if (resultTypeValues.empty())
939 op->
emitOpError() <<
"unable to infer result type for operation";
940 llvm_unreachable(
"unable to infer result type for operation");
948 struct PDLToPDLInterpPass
949 :
public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
950 PDLToPDLInterpPass() =
default;
951 PDLToPDLInterpPass(
const PDLToPDLInterpPass &rhs) =
default;
953 : configMap(&configMap) {}
954 void runOnOperation() final;
963 void PDLToPDLInterpPass::runOnOperation() {
964 ModuleOp module = getOperation();
968 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
969 auto matcherFunc = builder.
create<pdl_interp::FuncOp>(
970 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
977 ModuleOp rewriterModule = builder.
create<ModuleOp>(
978 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
981 PatternLowering
generator(matcherFunc, rewriterModule, configMap);
985 for (pdl::PatternOp pattern :
986 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
989 configMap->erase(pattern);
996 return std::make_unique<PDLToPDLInterpPass>();
1000 return std::make_unique<PDLToPDLInterpPass>(configMap);
static MLIRContext * getContext(OpFoldResult val)
static const mlir::GenInfo * generator
static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder, llvm::MapVector< Qualifier *, Block * > &dests)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
void erase()
Unlink this Block from its parent region and delete it.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
void push_back(Operation *op)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a set of configurations for a specific pattern.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class represents the base of a predicate matcher node.
Position * getPosition() const
Returns the position on which the question predicate should be checked.
std::unique_ptr< MatcherNode > & getFailureNode()
Returns the node that should be visited if this, or a subsequent node fails.
Qualifier * getQuestion() const
Returns the predicate checked on this node.
A position describes a value on the input IR on which a predicate may be applied, such as an operation or attribute.
Position * getParent() const
Returns the parent position. The root operation position has no parent.
Predicates::Kind getKind() const
Returns the kind of this position.
This class provides utilities for constructing predicates.
This class provides a storage uniquer that is used to allocate predicate instances.
An ordinal predicate consists of a "Question" and a set of acceptable "Answers" (later converted to ordinal values).
Predicates::Kind getKind() const
Returns the kind of this qualifier.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
@ ResultCountAtLeastQuestion
@ OperationPos
Positions, ordered by decreasing priority.
@ OperandCountAtLeastQuestion
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< OperationPass< ModuleOp > > createPDLToPDLInterpPass()
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
A BoolNode denotes a question with a boolean-like result.
std::unique_ptr< MatcherNode > & getSuccessNode()
Returns the node that should be visited on success.
Qualifier * getAnswer() const
Returns the expected answer of this boolean node.
An Answer representing an OperationName value.
A SuccessNode denotes that a given high level pattern has successfully been matched.
pdl::PatternOp getPattern() const
Return the high level pattern operation that is matched with this node.
Value getRoot() const
Return the chosen root of the pattern.
A SwitchNode denotes a question with multiple potential results.
std::pair< Qualifier *, std::unique_ptr< MatcherNode > > & getChild(unsigned i)
Returns the child at the given index.
ChildMapT & getChildren()