25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/Support/Allocator.h"
29#define GEN_PASS_DEF_REDUCTIONTREEPASS
30#include "mlir/Reducer/Passes.h.inc"
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
45 for (
const auto &op : enumerate(region.
getOps())) {
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
50 if (keepIndex == rangeToKeep.size() ||
index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
53 opsInRange.push_back(&op.value());
69 if (eraseOpNotInRange)
83template <
typename IteratorType>
86 const Tester &test,
bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
92 return module.emitWarning() << "uninterested module will not be reduced";
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
96 std::vector<ReductionNode::Range> ranges{
103 llvm_unreachable(
"unexpected initialization failure");
107 IteratorType iter(root);
109 while (iter != IteratorType::end()) {
119 smallestNode = ¤tNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
131 trace.push_back(curNode);
135 while (!trace.empty()) {
141 llvm::report_fatal_error(
"Reduced module is not interesting");
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
152 std::pair<Tester::Interestingness, size_t> initStatus =
158 return module.emitWarning() << "uninterested module will not be reduced";
159 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
163 std::vector<ReductionNode::Range> ranges{{0, 0}};
173 llvm_unreachable(
"unexpected initialization failure");
187template <
typename IteratorType>
216class ReductionPatternInterfaceCollection
222 void populateReductionPatterns(RewritePatternSet &pattern,
223 Tester &tester)
const {
224 for (
const DialectReductionPatternInterface &interface : *
this) {
225 interface.populateReductionPatterns(pattern);
226 interface.populateReductionPatternsWithTester(pattern, tester);
238class ReductionTreePass
239 :
public impl::ReductionTreePassBase<ReductionTreePass> {
243 LogicalResult
initialize(MLIRContext *context)
override;
246 void runOnOperation()
override;
249 LogicalResult reduceOp(ModuleOp module, Region ®ion);
252 FrozenRewritePatternSet reducerPatterns;
257LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
261 RewritePatternSet patterns(context);
263 ReductionPatternInterfaceCollection reducePatternCollection(context);
264 reducePatternCollection.populateReductionPatterns(patterns, tester);
266 reducerPatterns = std::move(patterns);
270void ReductionTreePass::runOnOperation() {
271 Operation *topOperation = getOperation();
274 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
277 <<
"top-level op must be 'builtin.module'";
278 return signalPassFailure();
281 SmallVector<Operation *, 8> workList;
282 workList.push_back(getOperation());
285 Operation *op = workList.pop_back_val();
289 if (
failed(reduceOp(module, region)))
290 return signalPassFailure();
293 for (Operation &op : region.getOps())
295 workList.push_back(&op);
296 }
while (!workList.empty());
299LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
300 switch (traversalModeId) {
301 case TraversalMode::SinglePath:
303 module, region, reducerPatterns, tester);
305 return module.emitError() << "unsupported traversal mode detected";
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region, const Tester &test)
This function attempts to erase all operations within the region currently being processed.
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode's interestingness hasn't been tested yet, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns whether the module exhibits the interesting behavior.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node was reduced from its parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
This class is used to keep track of the testing environment of the tool.
void setTestScriptArgs(ArrayRef< std::string > args)
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on an MLIR test case file.
void setTestScript(StringRef script)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
@ ExistingOps
Only pre-existing ops are processed.