18#include "llvm/Support/CommandLine.h"
23#define GEN_PASS_DEF_AFFINELOOPUNROLL
24#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
28#define DEBUG_TYPE "affine-loop-unroll"
42struct LoopUnroll :
public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
45 const std::function<
unsigned(AffineForOp)> getUnrollFactor;
47 LoopUnroll() : getUnrollFactor(
nullptr) {}
48 LoopUnroll(
const LoopUnroll &other) =
default;
50 std::optional<unsigned> unrollFactor = std::nullopt,
51 bool unrollUpToFactor =
false,
52 const std::function<
unsigned(AffineForOp)> &getUnrollFactor =
nullptr)
53 : getUnrollFactor(getUnrollFactor) {
55 this->unrollFactor = *unrollFactor;
56 this->unrollUpToFactor = unrollUpToFactor;
59 void runOnOperation()
override;
62 LogicalResult runOnAffineForOp(AffineForOp forOp);
69 ->walk([&](AffineForOp nestedForOp) {
78 f.walk([&](AffineForOp forOp) {
80 loops.push_back(forOp);
84void LoopUnroll::runOnOperation() {
85 if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
87 "Invalid option: 'unroll-factor' should be greater than 0 or "
89 return signalPassFailure();
91 FunctionOpInterface func = getOperation();
92 if (func.isExternal())
95 if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
97 SmallVector<AffineForOp, 4> loops;
102 getOperation().walk([&](AffineForOp forOp) {
103 std::optional<APInt> tripCount = forOp.getStaticTripCount();
104 if (tripCount && tripCount->getZExtValue() <= unrollFullThreshold)
105 loops.push_back(forOp);
107 for (
auto forOp : loops)
113 SmallVector<AffineForOp, 4> loops;
114 for (
unsigned i = 0; i < numRepetitions || getUnrollFactor; i++) {
119 bool unrolled =
false;
120 for (
auto forOp : loops)
121 unrolled |= succeeded(runOnAffineForOp(forOp));
130LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
134 nullptr, cleanUpUnroll);
136 if (unrollFactor.getValue() == -1)
139 if (unrollUpToFactor)
145std::unique_ptr<InterfacePass<FunctionOpInterface>>
147 int unrollFactor,
bool unrollUpToFactor,
148 const std::function<
unsigned(AffineForOp)> &getUnrollFactor) {
149 return std::make_unique<LoopUnroll>(
150 unrollFactor == -1 ? std::nullopt : std::optional<unsigned>(unrollFactor),
151 unrollUpToFactor, getUnrollFactor);
static void gatherInnermostLoops(FunctionOpInterface f, SmallVectorImpl< AffineForOp > &loops)
Gathers loops that have no affine.for ops nested within them.
static bool isInnermostAffineForOp(AffineForOp op)
Returns true if no other affine.for ops are nested within op.
static WalkResult interrupt()
LogicalResult loopUnrollFull(AffineForOp forOp)
Unrolls this for operation completely if the trip count is known to be constant.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor)
Unrolls this loop by the specified unroll factor or its trip count, whichever is lower.
std::unique_ptr< InterfacePass< FunctionOpInterface > > createLoopUnrollPass(int unrollFactor=-1, bool unrollUpToFactor=false, const std::function< unsigned(AffineForOp)> &getUnrollFactor=nullptr)
Creates a loop unrolling pass with the provided parameters.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.