MLIR 22.0.0git
Query.cpp
Go to the documentation of this file.
1//===---- Query.cpp - -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Query/Query.h"
10#include "QueryParser.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Verifier.h"
16#include "llvm/Support/SourceMgr.h"
17#include "llvm/Support/raw_ostream.h"
18
19namespace mlir::query {
20
21QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
22 return QueryParser::parse(line, qs);
23}
24
25std::vector<llvm::LineEditor::Completion>
26complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
27 return QueryParser::complete(line, pos, qs);
28}
29
30// TODO: Extract into a helper function that can be reused outside query
31// context.
32static Operation *extractFunction(std::vector<Operation *> &ops,
33 MLIRContext *context,
34 llvm::StringRef functionName) {
35 context->loadDialect<func::FuncDialect>();
36 OpBuilder builder(context);
37
38 // Collect data for function creation
39 std::vector<Operation *> slice;
40 std::vector<Value> values;
41 std::vector<Type> outputTypes;
42
43 for (auto *op : ops) {
44 // Return op's operands are propagated, but the op itself isn't needed.
45 if (!isa<func::ReturnOp>(op))
46 slice.push_back(op);
47
48 // All results are returned by the extracted function.
49 llvm::append_range(outputTypes, op->getResults().getTypes());
50
51 // Track all values that need to be taken as input to function.
52 llvm::append_range(values, op->getOperands());
53 }
54
55 // Create the function
56 FunctionType funcType =
57 builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
58 auto loc = builder.getUnknownLoc();
59 func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
60
61 builder.setInsertionPointToEnd(funcOp.addEntryBlock());
62
63 // Map original values to function arguments
64 IRMapping mapper;
65 for (const auto &arg : llvm::enumerate(values))
66 mapper.map(arg.value(), funcOp.getArgument(arg.index()));
67
68 // Clone operations and build function body
69 std::vector<Operation *> clonedOps;
70 std::vector<Value> clonedVals;
71 // TODO: Handle extraction of operations with compute payloads defined via
72 // regions.
73 for (Operation *slicedOp : slice) {
74 Operation *clonedOp =
75 clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
76 clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
77 clonedOp->result_end());
78 }
79 // Add return operation
80 func::ReturnOp::create(builder, loc, clonedVals);
81
82 // Remove unused function arguments
83 size_t currentIndex = 0;
84 while (currentIndex < funcOp.getNumArguments()) {
85 // Erase if possible.
86 if (funcOp.getArgument(currentIndex).use_empty())
87 if (succeeded(funcOp.eraseArgument(currentIndex)))
88 continue;
89 ++currentIndex;
90 }
91
92 return funcOp;
93}
94
// Destructor defined out of line in this .cpp rather than in the header.
// NOTE(review): presumably this anchors Query's vtable in this translation
// unit (LLVM key-function convention) — confirm against Query.h.
Query::~Query() = default;
96
97LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
98 os << errStr << "\n";
99 return mlir::failure();
100}
101
102LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
103 return mlir::success();
104}
105
106LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
107 os << "Available commands:\n\n"
108 " match MATCHER, m MATCHER "
109 "Match the mlir against the given matcher.\n"
110 " quit "
111 "Terminates the query session.\n\n";
112 return mlir::success();
113}
114
115LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
116 qs.terminate = true;
117 return mlir::success();
118}
119
120LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
121 Operation *rootOp = qs.getRootOp();
122 int matchCount = 0;
124
125 StringRef functionName = matcher.getFunctionName();
126 auto matches = finder.collectMatches(rootOp, std::move(matcher));
127
128 // An extract call is recognized by considering if the matcher has a name.
129 // TODO: Consider making the extract more explicit.
130 if (!functionName.empty()) {
131 std::vector<Operation *> flattenedMatches =
132 finder.flattenMatchedOps(matches);
133 Operation *function =
134 extractFunction(flattenedMatches, rootOp->getContext(), functionName);
135 if (failed(verify(function)))
136 return mlir::failure();
137 os << "\n" << *function << "\n\n";
138 function->erase();
139 return mlir::success();
140 }
141
142 os << "\n";
143 for (auto &results : matches) {
144 os << "Match #" << ++matchCount << ":\n\n";
145 for (Operation *op : results.matchedOps) {
146 if (op == results.rootOp) {
147 finder.printMatch(os, qs, op, "root");
148 } else {
149 finder.printMatch(os, qs, op);
150 }
151 }
152 }
153 os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
154 return mlir::success();
155}
156
157} // namespace mlir::query
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_iterator result_begin()
Definition Operation.h:413
result_iterator result_end()
Definition Operation.h:414
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
void erase()
Remove this operation from its parent block and delete it.
static QueryRef parse(llvm::StringRef line, const QuerySession &qs)
static std::vector< llvm::LineEditor::Completion > complete(llvm::StringRef line, size_t pos, const QuerySession &qs)
Finds and collects matches from the IR.
Definition MatchFinder.h:27
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const
Prints the matched operation.
std::vector< MatchResult > collectMatches(Operation *root, DynMatcher matcher) const
Traverses the IR and returns a vector of MatchResult for each match of the matcher.
std::vector< Operation * > flattenMatchedOps(std::vector< MatchResult > &matches) const
Flattens a vector of MatchResult into a vector of operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition Query.cpp:21
llvm::IntrusiveRefCntPtr< Query > QueryRef
Definition Query.h:36
static Operation * extractFunction(std::vector< Operation * > &ops, MLIRContext *context, llvm::StringRef functionName)
Definition Query.cpp:32
std::vector< llvm::LineEditor::Completion > complete(llvm::StringRef line, size_t pos, const QuerySession &qs)
Definition Query.cpp:26
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition Verifier.cpp:423
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition Query.cpp:106
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition Query.cpp:97
std::string errStr
Definition Query.h:50
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition Query.cpp:120
const matcher::DynMatcher matcher
Definition Query.h:97
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition Query.cpp:102
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition Query.cpp:115