MLIR  19.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
13 
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
23 
24 namespace mlir::arith {
25 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
26 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
27 } // namespace mlir::arith
28 
29 using namespace mlir;
30 using namespace mlir::arith;
31 using namespace mlir::dataflow;
32 
33 static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
34  Value value) {
35  auto *maybeInferredRange =
37  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
38  return std::nullopt;
39  const ConstantIntRanges &inferredRange =
40  maybeInferredRange->getValue().getValue();
41  return inferredRange.getConstantValue();
42 }
43 
44 /// Patterned after SCCP
46  PatternRewriter &rewriter,
47  Value value) {
48  if (value.use_empty())
49  return failure();
50  std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
51  if (!maybeConstValue.has_value())
52  return failure();
53 
54  Operation *maybeDefiningOp = value.getDefiningOp();
55  Dialect *valueDialect =
56  maybeDefiningOp ? maybeDefiningOp->getDialect()
57  : value.getParentRegion()->getParentOp()->getDialect();
58  Attribute constAttr =
59  rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
60  Operation *constOp = valueDialect->materializeConstant(
61  rewriter, constAttr, value.getType(), value.getLoc());
62  // Fall back to arith.constant if the dialect materializer doesn't know what
63  // to do with an integer constant.
64  if (!constOp)
65  constOp = rewriter.getContext()
66  ->getLoadedDialect<ArithDialect>()
67  ->materializeConstant(rewriter, constAttr, value.getType(),
68  value.getLoc());
69  if (!constOp)
70  return failure();
71 
72  rewriter.replaceAllUsesWith(value, constOp->getResult(0));
73  return success();
74 }
75 
76 namespace {
77 class DataFlowListener : public RewriterBase::Listener {
78 public:
79  DataFlowListener(DataFlowSolver &s) : s(s) {}
80 
81 protected:
82  void notifyOperationErased(Operation *op) override {
83  s.eraseState(op);
84  for (Value res : op->getResults())
85  s.eraseState(res);
86  }
87 
88  DataFlowSolver &s;
89 };
90 
91 /// Rewrite any results of `op` that were inferred to be constant integers to
92 /// and replace their uses with that constant. Return success() if all results
93 /// where thus replaced and the operation is erased. Also replace any block
94 /// arguments with their constant values.
95 struct MaterializeKnownConstantValues : public RewritePattern {
96  MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
97  : RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
98  solver(s) {}
99 
100  LogicalResult match(Operation *op) const override {
101  if (matchPattern(op, m_Constant()))
102  return failure();
103 
104  auto needsReplacing = [&](Value v) {
105  return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
106  };
107  bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
108  if (op->getNumRegions() == 0)
109  return success(hasConstantResults);
110  bool hasConstantRegionArgs = false;
111  for (Region &region : op->getRegions()) {
112  for (Block &block : region.getBlocks()) {
113  hasConstantRegionArgs |=
114  llvm::any_of(block.getArguments(), needsReplacing);
115  }
116  }
117  return success(hasConstantResults || hasConstantRegionArgs);
118  }
119 
120  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
121  bool replacedAll = (op->getNumResults() != 0);
122  for (Value v : op->getResults())
123  replacedAll &=
124  (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
125  v.use_empty());
126  if (replacedAll && isOpTriviallyDead(op)) {
127  rewriter.eraseOp(op);
128  return;
129  }
130 
131  PatternRewriter::InsertionGuard guard(rewriter);
132  for (Region &region : op->getRegions()) {
133  for (Block &block : region.getBlocks()) {
134  rewriter.setInsertionPointToStart(&block);
135  for (BlockArgument &arg : block.getArguments()) {
136  (void)maybeReplaceWithConstant(solver, rewriter, arg);
137  }
138  }
139  }
140  }
141 
142 private:
143  DataFlowSolver &solver;
144 };
145 
146 template <typename RemOp>
147 struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
148  DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
149  : OpRewritePattern<RemOp>(context), solver(s) {}
150 
151  LogicalResult matchAndRewrite(RemOp op,
152  PatternRewriter &rewriter) const override {
153  Value lhs = op.getOperand(0);
154  Value rhs = op.getOperand(1);
155  auto maybeModulus = getConstantIntValue(rhs);
156  if (!maybeModulus.has_value())
157  return failure();
158  int64_t modulus = *maybeModulus;
159  if (modulus <= 0)
160  return failure();
161  auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
162  if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
163  return failure();
164  const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
165  const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
166  const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
167  // The minima and maxima here are given as closed ranges, we must be
168  // strictly less than the modulus.
169  if (min.isNegative() || min.uge(modulus))
170  return failure();
171  if (max.isNegative() || max.uge(modulus))
172  return failure();
173  if (!min.ule(max))
174  return failure();
175 
176  // With all those conditions out of the way, we know thas this invocation of
177  // a remainder is a noop because the input is strictly within the range
178  // [0, modulus), so get rid of it.
179  rewriter.replaceOp(op, ValueRange{lhs});
180  return success();
181  }
182 
183 private:
184  DataFlowSolver &solver;
185 };
186 
187 struct IntRangeOptimizationsPass
188  : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
189 
190  void runOnOperation() override {
191  Operation *op = getOperation();
192  MLIRContext *ctx = op->getContext();
193  DataFlowSolver solver;
194  solver.load<DeadCodeAnalysis>();
195  solver.load<IntegerRangeAnalysis>();
196  if (failed(solver.initializeAndRun(op)))
197  return signalPassFailure();
198 
199  DataFlowListener listener(solver);
200 
201  RewritePatternSet patterns(ctx);
202  populateIntRangeOptimizationsPatterns(patterns, solver);
203 
204  GreedyRewriteConfig config;
205  config.listener = &listener;
206 
207  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
208  signalPassFailure();
209  }
210 };
211 } // namespace
212 
214  RewritePatternSet &patterns, DataFlowSolver &solver) {
215  patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
216  DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
217 }
218 
220  return std::make_unique<IntRangeOptimizationsPass>();
221 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static std::optional< APInt > getMaybeConstantValue(DataFlowSolver &solver, Value value)
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, PatternRewriter &rewriter, Value value)
Patterned after SCCP.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
MLIRContext * getContext() const
Definition: Builders.h:55
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
std::optional< APInt > getConstantValue() const
If either the signed or unsigned interpretation of the range indicates that the value it bounds is a constant, return that constant value.
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
void eraseState(PointT point)
Erase any analysis state associated with the given program point.
const StateT * lookupState(PointT point) const
Lookup an analysis state for the given program point.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the child analyses starting from the provided top-level operation and run the analysis until a fixpoint is reached.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition: Dialect.h:86
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
RewritePattern is the common base class for all DAG to DAG replacements.
Definition: PatternMatch.h:246
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358