MLIR 23.0.0git
IntegerDivisibilityAnalysis.cpp
Go to the documentation of this file.
1//===- IntegerDivisibilityAnalysis.cpp - Integer divisibility ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the dataflow analysis class for integer divisibility
10// inference. Operations participate in the analysis by implementing
11// `InferIntDivisibilityOpInterface`.
12//
13//===----------------------------------------------------------------------===//
14
16
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "int-divisibility-analysis"
22
23using llvm::dbgs;
24
25namespace mlir::dataflow {
26
32
36 auto inferrable = dyn_cast<InferIntDivisibilityOpInterface>(op);
37 if (!inferrable) {
38 setAllToEntryStates(results);
39 return success();
40 }
41
42 LLVM_DEBUG(dbgs() << "Inferring divisibility for " << *op << "\n");
43 auto argDivs = llvm::map_to_vector(
44 operands, [](const IntegerDivisibilityLattice *lattice) {
45 return lattice->getValue();
46 });
// Callback passed to `inferrable.inferResultDivisibility` below: for each
// value the interface reports, join the inferred divisibility into the
// lattice of the matching op result and propagate any change.
47 auto joinCallback = [&](Value v, const IntegerDivisibility &newDiv) {
// Only op results have a slot in `results`; any other value is ignored.
48 auto result = dyn_cast<OpResult>(v);
49 if (!result) {
50 return;
51 }
52 assert(llvm::is_contained(op->getResults(), result));
53
54 LLVM_DEBUG(dbgs() << "Inferred divisibility " << newDiv << "\n");
55 IntegerDivisibilityLattice *lattice = results[result.getResultNumber()];
// Snapshot the pre-join state so a change caused by this join can be
// detected below.
56 IntegerDivisibility oldDiv = lattice->getValue();
57
58 ChangeResult changed = lattice->join(newDiv);
59
60 // Catch loop results with loop-variant divisibility and conservatively
61 // set them to divisibility 1 (no information) so we don't ratchet
62 // indefinitely (the dataflow analysis in MLIR doesn't attempt to work
63 // out trip counts and often can't).
// NOTE(review): a result counts as "yielded" if ANY of its users is a
// terminator (not only loop yields). The fallback fires only when the
// lattice was already initialized and the join above changed it.
// NOTE(review): the inner lambda parameter `op` shadows the enclosing
// `op`; renaming it (e.g. `user`) would avoid -Wshadow noise.
64 bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
65 return op->hasTrait<OpTrait::IsTerminator>();
66 });
67 if (isYieldedResult && !oldDiv.isUninitialized() &&
68 !(lattice->getValue() == oldDiv)) {
69 LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
70 changed |= lattice->join(IntegerDivisibility::getMinDivisibility());
71 }
// Notify dependents only if either join above modified the lattice.
72 propagateIfChanged(lattice, changed);
73 };
74
75 inferrable.inferResultDivisibility(argDivs, joinCallback);
76 return success();
77}
78
80 Operation *op, const RegionSuccessor &successor, ValueRange successorInputs,
82 // Get the constant divisibility, or query the lattice for Values.
83 auto getDivFromOfr = [&](std::optional<OpFoldResult> ofr, Block *block,
84 bool isUnsigned) -> uint64_t {
85 if (ofr.has_value()) {
86 if (auto constBound = getConstantIntValue(*ofr)) {
87 return constBound.value();
88 }
89 auto value = cast<Value>(ofr.value());
90 const IntegerDivisibilityLattice *lattice =
92 if (lattice != nullptr && !lattice->getValue().isUninitialized()) {
93 return isUnsigned ? lattice->getValue().getValue().udiv()
94 : lattice->getValue().getValue().sdiv();
95 }
96 }
97 return isUnsigned
100 };
101
102 // Infer bounds for loop arguments that have static bounds
103 if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
104 std::optional<SmallVector<Value>> ivs = loop.getLoopInductionVars();
105 std::optional<SmallVector<OpFoldResult>> lbs = loop.getLoopLowerBounds();
106 std::optional<SmallVector<OpFoldResult>> steps = loop.getLoopSteps();
107 if (!ivs || !lbs || !steps) {
109 op, successor, successorInputs, argLattices);
110 }
111 for (auto [iv, lb, step] : llvm::zip_equal(*ivs, *lbs, *steps)) {
113 Block *block = iv.getParentBlock();
114 uint64_t stepUDiv = getDivFromOfr(step, block, /*unsigned=*/true);
115 uint64_t stepSDiv = getDivFromOfr(step, block, /*unsigned=*/false);
116 uint64_t lbUDiv = getDivFromOfr(lb, block, /*unsigned=*/true);
117 uint64_t lbSDiv = getDivFromOfr(lb, block, /*unsigned=*/false);
118 ConstantIntDivisibility lbDiv(lbUDiv, lbSDiv);
119 ConstantIntDivisibility stepDiv(stepUDiv, stepSDiv);
120
121 // Loop induction variables are computed as `lb + i * step`. The
122 // divisibility for `i * step` is just the divisibility of `step`, so
123 // the total divisibility is obtained by unioning the step divisibility
124 // with the lower bound divisibility, which takes the GCD of the two.
125 ConstantIntDivisibility ivDiv = stepDiv.getUnion(lbDiv);
126 propagateIfChanged(ivEntry, ivEntry->join(ivDiv));
127 }
128 return;
129 }
130
132 op, successor, successorInputs, argLattices);
133}
134
135} // namespace mlir::dataflow
return success()
Block represents an ordered list of Operations.
Definition Block.h:33
Statically known divisibility information for an integer SSA value.
ConstantIntDivisibility getUnion(const ConstantIntDivisibility &other) const
ProgramPoint * getProgramPointBefore(Operation *op)
Get a uniqued program point instance.
void propagateIfChanged(AnalysisState *state, ChangeResult changed)
Propagate an update to a state if it changed.
This lattice value represents the integer divisibility of an SSA value.
const ConstantIntDivisibility & getValue() const
static IntegerDivisibility getMinDivisibility()
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
result_range getResults()
Definition Operation.h:440
This class represents a successor of a region.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
user_range getUsers() const
Definition Value.h:218
void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ValueRange successorInputs, ArrayRef< IntegerDivisibilityLattice * > argLattices) override
Visit block arguments or operation results of an operation with region control-flow for which values are not defined by region control-flow.
void setToEntryState(IntegerDivisibilityLattice *lattice) override
At an entry point, set the lattice to the most pessimistic state, indicating that no further reasoning can be done.
LogicalResult visitOperation(Operation *op, ArrayRef< const IntegerDivisibilityLattice * > operands, ArrayRef< IntegerDivisibilityLattice * > results) override
Visit an operation, invoking the transfer function.
This lattice element represents the integer divisibility of an SSA value.
ValueT & getValue()
Return the value held by this lattice.
ChangeResult join(const AbstractSparseLattice &rhs) override
Join the information contained in the 'rhs' lattice into this lattice.
IntegerDivisibilityLattice * getLatticeElement(Value value) override
void setAllToEntryStates(ArrayRef< IntegerDivisibilityLattice * > lattices)
virtual void visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor, ValueRange nonSuccessorInputs, ArrayRef< StateT * > nonSuccessorInputLattices)
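Given an operation with possible region control-flow, the lattices of the operands, and a region successor, compute the lattice values for block arguments that are not accounted for by the branching control flow (e.g. the bounds of loops).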
const IntegerDivisibilityLattice * getLatticeElementFor(ProgramPoint *point, Value value)
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
ChangeResult
A result type used to indicate if a change happened.