25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/Support/Allocator.h"
29 #define GEN_PASS_DEF_REDUCTIONTREEPASS
30 #include "mlir/Reducer/Passes.h.inc"
41 bool eraseOpNotInRange) {
42 std::vector<Operation *> opsNotInRange;
43 std::vector<Operation *> opsInRange;
46 int index = op.index();
47 if (keepIndex < rangeToKeep.size() &&
48 index == rangeToKeep[keepIndex].second)
50 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
51 opsNotInRange.push_back(&op.value());
53 opsInRange.push_back(&op.value());
69 if (eraseOpNotInRange)
83 template <
typename IteratorType>
86 const Tester &test,
bool eraseOpNotInRange) {
87 std::pair<Tester::Interestingness, size_t> initStatus =
92 return module.emitWarning() <<
"uninterested module will not be reduced";
94 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
96 std::vector<ReductionNode::Range> ranges{
103 llvm_unreachable(
"unexpected initialization failure");
107 IteratorType iter(root);
109 while (iter != IteratorType::end()) {
119 smallestNode = ¤tNode;
128 trace.push_back(curNode);
129 while (curNode != root) {
131 trace.push_back(curNode);
135 while (!trace.empty()) {
141 llvm::report_fatal_error(
"Reduced module is not interesting");
143 llvm::report_fatal_error(
144 "Reduced module doesn't have consistent size with smallestNode");
148 template <
typename IteratorType>
157 if (
failed(findOptimal<IteratorType>(module, region, {}, test,
162 return findOptimal<IteratorType>(module, region,
patterns, test,
172 class ReductionPatternInterfaceCollection
191 class ReductionTreePass
192 :
public impl::ReductionTreePassBase<ReductionTreePass> {
196 LogicalResult initialize(
MLIRContext *context)
override;
199 void runOnOperation()
override;
202 LogicalResult reduceOp(ModuleOp module,
Region ®ion);
209 LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
211 ReductionPatternInterfaceCollection reducePatternCollection(context);
212 reducePatternCollection.populateReductionPatterns(
patterns);
213 reducerPatterns = std::move(
patterns);
217 void ReductionTreePass::runOnOperation() {
218 Operation *topOperation = getOperation();
221 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224 <<
"top-level op must be 'builtin.module'";
225 return signalPassFailure();
229 workList.push_back(getOperation());
236 if (
failed(reduceOp(module, region)))
237 return signalPassFailure();
242 workList.push_back(&op);
243 }
while (!workList.empty());
246 LogicalResult ReductionTreePass::reduceOp(ModuleOp module,
Region ®ion) {
247 Tester test(testerName, testerArgs);
248 switch (traversalModeId) {
250 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
251 module, region, reducerPatterns, test);
253 return module.emitError() <<
"unsupported traversal mode detected";
static void applyPatterns(Region &region, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region &region, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteConfig & setStrictness(GreedyRewriteStrictness mode)
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested for interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information describing how this node was reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
This class is used to keep track of the testing environment of the tool.
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.