26 #include "llvm/ADT/ArrayRef.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/Allocator.h"
29 #include "llvm/Support/ManagedStatic.h"
32 #define GEN_PASS_DEF_REDUCTIONTREE
33 #include "mlir/Reducer/Passes.h.inc"
44 bool eraseOpNotInRange) {
45 std::vector<Operation *> opsNotInRange;
46 std::vector<Operation *> opsInRange;
49 int index = op.index();
50 if (keepIndex < rangeToKeep.size() &&
51 index == rangeToKeep[keepIndex].second)
53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54 opsNotInRange.push_back(&op.value());
56 opsInRange.push_back(&op.value());
72 if (eraseOpNotInRange)
86 template <
typename IteratorType>
89 const Tester &test,
bool eraseOpNotInRange) {
90 std::pair<Tester::Interestingness, size_t> initStatus =
95 return module.emitWarning() <<
"uninterested module will not be reduced";
97 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
99 std::vector<ReductionNode::Range> ranges{
106 llvm_unreachable(
"unexpected initialization failure");
110 IteratorType iter(root);
112 while (iter != IteratorType::end()) {
122 smallestNode = ¤tNode;
131 trace.push_back(curNode);
132 while (curNode != root) {
134 trace.push_back(curNode);
138 while (!trace.empty()) {
144 llvm::report_fatal_error(
"Reduced module is not interesting");
146 llvm::report_fatal_error(
147 "Reduced module doesn't have consistent size with smallestNode");
151 template <
typename IteratorType>
160 if (failed(findOptimal<IteratorType>(module, region, {}, test,
165 return findOptimal<IteratorType>(module, region,
patterns, test,
175 class ReductionPatternInterfaceCollection
194 class ReductionTreePass :
public impl::ReductionTreeBase<ReductionTreePass> {
196 ReductionTreePass() =
default;
197 ReductionTreePass(
const ReductionTreePass &pass) =
default;
199 LogicalResult initialize(
MLIRContext *context)
override;
202 void runOnOperation()
override;
205 LogicalResult reduceOp(ModuleOp module,
Region ®ion);
212 LogicalResult ReductionTreePass::initialize(
MLIRContext *context) {
214 ReductionPatternInterfaceCollection reducePatternCollection(context);
215 reducePatternCollection.populateReductionPatterns(
patterns);
216 reducerPatterns = std::move(
patterns);
220 void ReductionTreePass::runOnOperation() {
221 Operation *topOperation = getOperation();
224 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
227 <<
"top-level op must be 'builtin.module'";
228 return signalPassFailure();
232 workList.push_back(getOperation());
239 if (failed(reduceOp(module, region)))
240 return signalPassFailure();
245 workList.push_back(&op);
246 }
while (!workList.empty());
249 LogicalResult ReductionTreePass::reduceOp(ModuleOp module,
Region ®ion) {
250 Tester test(testerName, testerArgs);
251 switch (traversalModeId) {
253 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
254 module, region, reducerPatterns, test);
256 return module.emitError() <<
"unsupported traversal mode detected";
261 return std::make_unique<ReductionTreePass>();
static void applyPatterns(Region ®ion, const FrozenRewritePatternSet &patterns, ArrayRef< ReductionNode::Range > rangeToKeep, bool eraseOpNotInRange)
We implicitly number each operation in the region and if an operation's number falls into rangeToKeep...
static LogicalResult findOptimal(ModuleOp module, Region ®ion, const FrozenRewritePatternSet &patterns, const Tester &test, bool eraseOpNotInRange)
We will apply the reducer patterns to the operations in the ranges specified by ReductionNode.
A collection of dialect interfaces within a context, for a given concrete interface type.
This is used to report the reduction patterns for a Dialect.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const =0
Patterns provided here are intended to transform operations from a complex form to a simpler form,...
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
ReductionTreePass will build a reduction tree during module reduction and the ReductionNode represent...
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion)
Each Reduction Node contains a copy of module for applying rewrite patterns.
size_t getSize() const
Return the size of the module.
ModuleOp getModule() const
If the ReductionNode hasn't been tested the interestingness, it'll be the same module as the one in t...
Region & getRegion() const
Return the region we're reducing.
Tester::Interestingness isInteresting() const
Returns true if the module exhibits the interesting behavior.
ArrayRef< Range > getRanges() const
Return the range set we are using to generate variants.
ReductionNode * getParent() const
ArrayRef< Range > getStartRanges() const
Return the range information that how this node is reduced from the parent node.
void update(std::pair< Tester::Interestingness, size_t > result)
Update the interestingness result from tester.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
This class is used to keep track of the testing environment of the tool.
std::pair< Interestingness, size_t > isInteresting(ModuleOp module) const
Runs the interestingness testing script on a MLIR test case file.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createReductionTreePass()
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
@ ExistingOps
Only pre-existing ops are processed.