MLIR  22.0.0git
Query.cpp
Go to the documentation of this file.
1 //===---- Query.cpp - -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Query/Query.h"
10 #include "QueryParser.h"
12 #include "mlir/IR/IRMapping.h"
13 #include "mlir/IR/Verifier.h"
16 #include "llvm/Support/SourceMgr.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 namespace mlir::query {
20 
21 QueryRef parse(llvm::StringRef line, const QuerySession &qs) {
22  return QueryParser::parse(line, qs);
23 }
24 
25 std::vector<llvm::LineEditor::Completion>
26 complete(llvm::StringRef line, size_t pos, const QuerySession &qs) {
27  return QueryParser::complete(line, pos, qs);
28 }
29 
30 // TODO: Extract into a helper function that can be reused outside query
31 // context.
32 static Operation *extractFunction(std::vector<Operation *> &ops,
33  MLIRContext *context,
34  llvm::StringRef functionName) {
35  context->loadDialect<func::FuncDialect>();
36  OpBuilder builder(context);
37 
38  // Collect data for function creation
39  std::vector<Operation *> slice;
40  std::vector<Value> values;
41  std::vector<Type> outputTypes;
42 
43  for (auto *op : ops) {
44  // Return op's operands are propagated, but the op itself isn't needed.
45  if (!isa<func::ReturnOp>(op))
46  slice.push_back(op);
47 
48  // All results are returned by the extracted function.
49  llvm::append_range(outputTypes, op->getResults().getTypes());
50 
51  // Track all values that need to be taken as input to function.
52  llvm::append_range(values, op->getOperands());
53  }
54 
55  // Create the function
56  FunctionType funcType =
57  builder.getFunctionType(TypeRange(ValueRange(values)), outputTypes);
58  auto loc = builder.getUnknownLoc();
59  func::FuncOp funcOp = func::FuncOp::create(loc, functionName, funcType);
60 
61  builder.setInsertionPointToEnd(funcOp.addEntryBlock());
62 
63  // Map original values to function arguments
64  IRMapping mapper;
65  for (const auto &arg : llvm::enumerate(values))
66  mapper.map(arg.value(), funcOp.getArgument(arg.index()));
67 
68  // Clone operations and build function body
69  std::vector<Operation *> clonedOps;
70  std::vector<Value> clonedVals;
71  // TODO: Handle extraction of operations with compute payloads defined via
72  // regions.
73  for (Operation *slicedOp : slice) {
74  Operation *clonedOp =
75  clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
76  clonedVals.insert(clonedVals.end(), clonedOp->result_begin(),
77  clonedOp->result_end());
78  }
79  // Add return operation
80  func::ReturnOp::create(builder, loc, clonedVals);
81 
82  // Remove unused function arguments
83  size_t currentIndex = 0;
84  while (currentIndex < funcOp.getNumArguments()) {
85  // Erase if possible.
86  if (funcOp.getArgument(currentIndex).use_empty())
87  if (succeeded(funcOp.eraseArgument(currentIndex)))
88  continue;
89  ++currentIndex;
90  }
91 
92  return funcOp;
93 }
94 
95 Query::~Query() = default;
96 
97 LogicalResult InvalidQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
98  os << errStr << "\n";
99  return mlir::failure();
100 }
101 
102 LogicalResult NoOpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
103  return mlir::success();
104 }
105 
106 LogicalResult HelpQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
107  os << "Available commands:\n\n"
108  " match MATCHER, m MATCHER "
109  "Match the mlir against the given matcher.\n"
110  " quit "
111  "Terminates the query session.\n\n";
112  return mlir::success();
113 }
114 
115 LogicalResult QuitQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
116  qs.terminate = true;
117  return mlir::success();
118 }
119 
120 LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
121  Operation *rootOp = qs.getRootOp();
122  int matchCount = 0;
123  matcher::MatchFinder finder;
124  auto matches = finder.collectMatches(rootOp, std::move(matcher));
125 
126  // An extract call is recognized by considering if the matcher has a name.
127  // TODO: Consider making the extract more explicit.
128  if (matcher.hasFunctionName()) {
129  auto functionName = matcher.getFunctionName();
130  std::vector<Operation *> flattenedMatches =
131  finder.flattenMatchedOps(matches);
132  Operation *function =
133  extractFunction(flattenedMatches, rootOp->getContext(), functionName);
134  if (failed(verify(function)))
135  return mlir::failure();
136  os << "\n" << *function << "\n\n";
137  function->erase();
138  return mlir::success();
139  }
140 
141  os << "\n";
142  for (auto &results : matches) {
143  os << "Match #" << ++matchCount << ":\n\n";
144  for (Operation *op : results.matchedOps) {
145  if (op == results.rootOp) {
146  finder.printMatch(os, qs, op, "root");
147  } else {
148  finder.printMatch(os, qs, op);
149  }
150  }
151  }
152  os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
153  return mlir::success();
154 }
155 
156 } // namespace mlir::query
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:110
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
result_iterator result_end()
Definition: Operation.h:414
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
static QueryRef parse(llvm::StringRef line, const QuerySession &qs)
static std::vector< llvm::LineEditor::Completion > complete(llvm::StringRef line, size_t pos, const QuerySession &qs)
Finds and collects matches from the IR.
Definition: MatchFinder.h:27
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const
Prints the matched operation.
Definition: MatchFinder.cpp:39
std::vector< MatchResult > collectMatches(Operation *root, DynMatcher matcher) const
Traverses the IR and returns a vector of MatchResult for each match of the matcher.
Definition: MatchFinder.cpp:21
std::vector< Operation * > flattenMatchedOps(std::vector< MatchResult > &matches) const
Flattens a vector of MatchResult into a vector of operations.
Definition: MatchFinder.cpp:59
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
static Operation * extractFunction(std::vector< Operation * > &ops, MLIRContext *context, llvm::StringRef functionName)
Definition: Query.cpp:32
llvm::IntrusiveRefCntPtr< Query > QueryRef
Definition: Query.h:36
std::vector< llvm::LineEditor::Completion > complete(llvm::StringRef line, size_t pos, const QuerySession &qs)
Definition: Query.cpp:26
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition: Query.cpp:106
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition: Query.cpp:97
std::string errStr
Definition: Query.h:50
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition: Query.cpp:120
const matcher::DynMatcher matcher
Definition: Query.h:97
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition: Query.cpp:102
llvm::LogicalResult run(llvm::raw_ostream &os, QuerySession &qs) const override
Definition: Query.cpp:115