25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/Support/Allocator.h"
29#define GEN_PASS_DEF_REDUCTIONTREEPASS
30#include "mlir/Reducer/Passes.h.inc"
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
44 for (
const auto &op : enumerate(region.
getOps())) {
45 int index = op.index();
46 if (keepIndex < rangeToKeep.size() &&
47 index == rangeToKeep[keepIndex].second)
49 if (keepIndex == rangeToKeep.size() ||
index < rangeToKeep[keepIndex].first)
50 opsNotInRange.push_back(&op.value());
57 if (!eraseOpNotInRange)
67 if (eraseOpNotInRange)
81template <
typename IteratorType>
84 const Tester &test,
bool eraseOpNotInRange) {
85 std::pair<Tester::Interestingness, size_t> initStatus =
90 return module.emitError() << "uninterested module will not be reduced";
92 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
94 std::vector<ReductionNode::Range> ranges{
101 llvm_unreachable(
"unexpected initialization failure");
105 IteratorType iter(root);
107 while (iter != IteratorType::end()) {
117 smallestNode = &currentNode;
126 trace.push_back(curNode);
127 while (curNode != root) {
129 trace.push_back(curNode);
133 while (!trace.empty()) {
139 llvm::report_fatal_error(
"Reduced module is not interesting");
141 llvm::report_fatal_error(
142 "Reduced module doesn't have consistent size with smallestNode");
150 std::pair<Tester::Interestingness, size_t> initStatus =
156 return module.emitError() << "uninterested module will not be reduced";
157 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
161 std::vector<ReductionNode::Range> ranges{{0, 0}};
171 llvm_unreachable(
"unexpected initialization failure");
185template <
typename IteratorType>
214class ReductionPatternInterfaceCollection
222 for (
const DialectReductionPatternInterface &interface : *
this) {
223 interface.populateReductionPatterns(pattern);
224 interface.populateReductionPatternsWithTester(pattern, tester);
236class ReductionTreePass
244 void runOnOperation()
override;
247 LogicalResult reduceOp(ModuleOp module,
Region &region);
255LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
259 RewritePatternSet patterns(context);
261 ReductionPatternInterfaceCollection reducePatternCollection(context);
262 reducePatternCollection.populateReductionPatterns(patterns, tester);
264 reducerPatterns = std::move(patterns);
268void ReductionTreePass::runOnOperation() {
269 Operation *topOperation = getOperation();
272 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
275 <<
"top-level op must be 'builtin.module'";
276 return signalPassFailure();
279 SmallVector<Operation *, 8> workList;
280 workList.push_back(getOperation());
283 Operation *op = workList.pop_back_val();
287 if (
failed(reduceOp(module, region)))
288 return signalPassFailure();
291 for (Operation &op : region.
getOps())
293 workList.push_back(&op);
294 }
while (!workList.empty());
297LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
298 switch (traversalModeId) {
299 case TraversalMode::SinglePath:
301 module, region, reducerPatterns, tester);
303 return module.emitError() << "unsupported traversal mode detected";
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region, const Tester &test)
This function attempts to erase all operations within the region currently being processed.
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested for interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns whether the module exhibits the interesting behavior, as a Tester::Interestingness value.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node was reduced from its parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
This class is used to keep track of the testing environment of the tool.
void setTestScriptArgs(ArrayRef< std::string > args)
std::pair< Interestingness, size_t > isInteresting(Operation *topOp) const
Runs the interestingness testing script on a MLIR test case file.
void setTestScript(StringRef script)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
@ ExistingOps
Only pre-existing ops are processed.