MLIR  19.0.0git
IntRangeOptimizations.cpp
Go to the documentation of this file.
1 //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
12 
17 
18 namespace mlir::arith {
19 #define GEN_PASS_DEF_ARITHINTRANGEOPTS
20 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21 } // namespace mlir::arith
22 
23 using namespace mlir;
24 using namespace mlir::arith;
25 using namespace mlir::dataflow;
26 
27 /// Returns true if 2 integer ranges have intersection.
28 static bool intersects(const ConstantIntRanges &lhs,
29  const ConstantIntRanges &rhs) {
30  return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
31  (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
32 }
33 
// Body of handleEq (its signature is not shown in this listing): decides
// `lhs == rhs`. Folds to false when `intersects` reports that no value can lie
// in both ranges; otherwise the result is left undecided (failure()).
// NOTE(review): this never folds to true, even when both ranges collapse to
// the same single constant — a possible missed fold.
35  if (!intersects(lhs, rhs))
36  return false;
37 
38  return failure();
39 }
40 
// Body of handleNe (its signature is not shown in this listing): decides
// `lhs != rhs`, the dual of handleEq. Folds to true when `intersects` reports
// that no value can lie in both ranges; otherwise undecided (failure()).
42  if (!intersects(lhs, rhs))
43  return true;
44 
45  return failure();
46 }
47 
// Body of handleSlt (its signature is not shown in this listing): decides
// signed `lhs < rhs`.
// True if even the largest signed lhs value is below the smallest rhs value.
49  if (lhs.smax().slt(rhs.smin()))
50  return true;
51 
// False if even the smallest signed lhs value is >= the largest rhs value.
52  if (lhs.smin().sge(rhs.smax()))
53  return false;
54 
// The signed ranges overlap in a way that cannot decide the comparison.
55  return failure();
56 }
57 
// Body of handleSle (its signature is not shown in this listing): decides
// signed `lhs <= rhs`.
// True if the largest signed lhs value is <= the smallest rhs value.
59  if (lhs.smax().sle(rhs.smin()))
60  return true;
61 
// False if the smallest signed lhs value is > the largest rhs value.
62  if (lhs.smin().sgt(rhs.smax()))
63  return false;
64 
// Undecidable from the signed bounds alone.
65  return failure();
66 }
67 
// Body of handleSgt (its signature is not shown in this listing): signed
// `lhs > rhs` is `rhs < lhs`, so delegate to handleSlt with swapped operands.
69  return handleSlt(std::move(rhs), std::move(lhs));
70 }
71 
// Body of handleSge (its signature is not shown in this listing): signed
// `lhs >= rhs` is `rhs <= lhs`, so delegate to handleSle with swapped operands.
73  return handleSle(std::move(rhs), std::move(lhs));
74 }
75 
// Body of handleUlt (its signature is not shown in this listing): decides
// unsigned `lhs < rhs`.
// True if even the largest unsigned lhs value is below the smallest rhs value.
77  if (lhs.umax().ult(rhs.umin()))
78  return true;
79 
// False if even the smallest unsigned lhs value is >= the largest rhs value.
80  if (lhs.umin().uge(rhs.umax()))
81  return false;
82 
// The unsigned ranges overlap in a way that cannot decide the comparison.
83  return failure();
84 }
85 
// Body of handleUle (its signature is not shown in this listing): decides
// unsigned `lhs <= rhs`.
// True if the largest unsigned lhs value is <= the smallest rhs value.
87  if (lhs.umax().ule(rhs.umin()))
88  return true;
89 
// False if the smallest unsigned lhs value is > the largest rhs value.
90  if (lhs.umin().ugt(rhs.umax()))
91  return false;
92 
// Undecidable from the unsigned bounds alone.
93  return failure();
94 }
95 
// Body of handleUgt (its signature is not shown in this listing): unsigned
// `lhs > rhs` is `rhs < lhs`, so delegate to handleUlt with swapped operands.
97  return handleUlt(std::move(rhs), std::move(lhs));
98 }
99 
// Body of handleUge (its signature is not shown in this listing): unsigned
// `lhs >= rhs` is `rhs <= lhs`, so delegate to handleUle with swapped operands.
101  return handleUle(std::move(rhs), std::move(lhs));
102 }
103 
104 namespace {
/// Rewrite pattern that replaces an `arith.cmpi` with an `arith::ConstantIntOp`
/// of width 1 when the integer ranges that the dataflow solver computed for the
/// two operands decide the predicate. It holds a non-owning reference to the
/// solver, so the solver must outlive any pattern set that contains it.
105 struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
106 
// `s` is the solver whose lattice states are queried in matchAndRewrite.
107  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
108  : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
109 
// Rewrites `op` and succeeds only when the predicate is decided by the operand
// ranges; otherwise returns failure() and leaves `op` untouched.
110  LogicalResult matchAndRewrite(arith::CmpIOp op,
111  PatternRewriter &rewriter) const override {
// Bail out if the solver has no state, or an uninitialized one, for the LHS.
112  auto *lhsResult =
113  solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
114  if (!lhsResult || lhsResult->getValue().isUninitialized())
115  return failure();
116 
// Same requirement for the RHS.
117  auto *rhsResult =
118  solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
119  if (!rhsResult || rhsResult->getValue().isUninitialized())
120  return failure();
121 
// Dispatch table indexed by the numeric value of the predicate. Entries not
// assigned below stay null because `handlers{}` value-initializes the array.
// NOTE(review): the right-hand side of this alias (original line 123) is
// missing from this listing. The `&handleXxx` assignments suggest a plain
// function-pointer type taking two ConstantIntRanges and returning
// FailureOr<bool>; confirm against the source.
// NOTE(review): the table is rebuilt on every match attempt; it could be
// hoisted into a static table.
122  using HandlerFunc =
124  std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
125  handlers{};
126  using Pred = arith::CmpIPredicate;
127  handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
128  handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
129  handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
130  handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
131  handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
132  handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
133  handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
134  handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
135  handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
136  handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
137 
// A predicate without a registered handler is left alone.
138  HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
139  if (!handler)
140  return failure();
141 
// Copy the ranges out of the lattices and let the handler decide.
142  ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
143  ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
144  FailureOr<bool> result = handler(lhsValue, rhsValue);
145 
// failure() from the handler means the ranges are inconclusive.
146  if (failed(result))
147  return failure();
148 
// Materialize the decided boolean as a width-1 integer constant.
// NOTE(review): this always builds a scalar i1. If `arith.cmpi` on vector
// operands can reach this point with initialized ranges, the replacement type
// would not match the original vector-of-i1 result. Confirm this cannot
// happen, or guard on the result type.
149  rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
150  op, static_cast<int64_t>(*result), /*width*/ 1);
151  return success();
152  }
153 
154 private:
// Non-owning; the solver is owned by whoever populated the pattern set.
155  DataFlowSolver &solver;
156 };
157 
158 struct IntRangeOptimizationsPass
159  : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
160 
161  void runOnOperation() override {
162  Operation *op = getOperation();
163  MLIRContext *ctx = op->getContext();
164  DataFlowSolver solver;
165  solver.load<DeadCodeAnalysis>();
166  solver.load<IntegerRangeAnalysis>();
167  if (failed(solver.initializeAndRun(op)))
168  return signalPassFailure();
169 
170  RewritePatternSet patterns(ctx);
171  populateIntRangeOptimizationsPatterns(patterns, solver);
172 
173  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
174  signalPassFailure();
175  }
176 };
177 } // namespace
178 
180  RewritePatternSet &patterns, DataFlowSolver &solver) {
// Body of populateIntRangeOptimizationsPatterns (its first signature line is
// not shown in this listing). It registers ConvertCmpOp. The pattern stores
// `solver` by reference, so the solver must outlive `patterns` and any frozen
// pattern set built from it.
181  patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
182 }
183 
// Body of createIntRangeOptimizationsPass (its signature is not shown in this
// listing): returns a new IntRangeOptimizationsPass owned by the caller through
// the returned std::unique_ptr.
185  return std::make_unique<IntRangeOptimizationsPass>();
186 }
static FailureOr< bool > handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs)
static bool intersects(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs)
Returns true if two integer ranges may intersect, i.e. they are not provably disjoint.
static FailureOr< bool > handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs)
static FailureOr< bool > handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs)
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smax() const
The maximum value of an integer when it is interpreted as signed.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
const APInt & umax() const
The maximum value of an integer when it is interpreted as unsigned.
const APInt & umin() const
The minimum value of an integer when it is interpreted as unsigned.
The general data-flow analysis solver.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
std::unique_ptr< Pass > createIntRangeOptimizationsPass()
Create a pass that performs optimizations based on integer range analysis.
void populateIntRangeOptimizationsPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Add patterns for int range based optimizations.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358