MLIR  20.0.0git
UnsignedWhenEquivalent.cpp
Go to the documentation of this file.
1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2 // unsigned
3 // ones when all their arguments and results are statically non-negative --===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
12 
16 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 namespace arith {
21 #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
22 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
23 } // namespace arith
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::arith;
28 using namespace mlir::dataflow;
29 
30 /// Succeeds when a value is statically non-negative in that it has a lower
31 /// bound on its value (if it is treated as signed) and that bound is
32 /// non-negative.
33 // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
34 // relies on this. These transformations may not be valid for 32bit index,
35 // need more investigation.
36 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
37  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
38  if (!result || result->getValue().isUninitialized())
39  return failure();
40  const ConstantIntRanges &range = result->getValue().getValue();
41  return success(range.smin().isNonNegative());
42 }
43 
44 /// Succeeds if an op can be converted to its unsigned equivalent without
45 /// changing its semantics. This is the case when none of its openands or
46 /// results can be below 0 when analyzed from a signed perspective.
47 static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
48  Operation *op) {
49  auto nonNegativePred = [&solver](Value v) -> bool {
50  return succeeded(staticallyNonNegative(solver, v));
51  };
52  return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
53  llvm::all_of(op->getResults(), nonNegativePred));
54 }
55 
56 /// Succeeds when the comparison predicate is a signed operation and all the
57 /// operands are non-negative, indicating that the cmpi operation `op` can have
58 /// its predicate changed to an unsigned equivalent.
59 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
60  CmpIPredicate pred = op.getPredicate();
61  switch (pred) {
62  case CmpIPredicate::sle:
63  case CmpIPredicate::slt:
64  case CmpIPredicate::sge:
65  case CmpIPredicate::sgt:
66  return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
67  return succeeded(staticallyNonNegative(solver, v));
68  }));
69  default:
70  return failure();
71  }
72 }
73 
74 /// Return the unsigned equivalent of a signed comparison predicate,
75 /// or the predicate itself if there is none.
76 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
77  switch (pred) {
78  case CmpIPredicate::sle:
79  return CmpIPredicate::ule;
80  case CmpIPredicate::slt:
81  return CmpIPredicate::ult;
82  case CmpIPredicate::sge:
83  return CmpIPredicate::uge;
84  case CmpIPredicate::sgt:
85  return CmpIPredicate::ugt;
86  default:
87  return pred;
88  }
89 }
90 
91 namespace {
92 class DataFlowListener : public RewriterBase::Listener {
93 public:
94  DataFlowListener(DataFlowSolver &s) : s(s) {}
95 
96 protected:
97  void notifyOperationErased(Operation *op) override {
98  s.eraseState(s.getProgramPointAfter(op));
99  for (Value res : op->getResults())
100  s.eraseState(res);
101  }
102 
103  DataFlowSolver &s;
104 };
105 
106 template <typename Signed, typename Unsigned>
107 struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
108  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
109  : OpRewritePattern<Signed>(context), solver(s) {}
110 
111  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
112  if (failed(
113  staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
114  return failure();
115 
116  rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
117  op->getAttrs());
118  return success();
119  }
120 
121 private:
122  DataFlowSolver &solver;
123 };
124 
125 struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
126  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
127  : OpRewritePattern<CmpIOp>(context), solver(s) {}
128 
129  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
130  if (failed(isCmpIConvertable(this->solver, op)))
131  return failure();
132 
133  rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
134  op.getLhs(), op.getRhs());
135  return success();
136  }
137 
138 private:
139  DataFlowSolver &solver;
140 };
141 
142 struct ArithUnsignedWhenEquivalentPass
143  : public arith::impl::ArithUnsignedWhenEquivalentBase<
144  ArithUnsignedWhenEquivalentPass> {
145 
146  void runOnOperation() override {
147  Operation *op = getOperation();
148  MLIRContext *ctx = op->getContext();
149  DataFlowSolver solver;
150  solver.load<DeadCodeAnalysis>();
151  solver.load<IntegerRangeAnalysis>();
152  if (failed(solver.initializeAndRun(op)))
153  return signalPassFailure();
154 
155  DataFlowListener listener(solver);
156 
157  RewritePatternSet patterns(ctx);
158  populateUnsignedWhenEquivalentPatterns(patterns, solver);
159 
160  walkAndApplyPatterns(op, std::move(patterns), &listener);
161  }
162 };
163 } // end anonymous namespace
164 
166  RewritePatternSet &patterns, DataFlowSolver &solver) {
167  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
168  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
169  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
170  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
171  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
172  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
173  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
174  patterns.getContext(), solver);
175 }
176 
178  return std::make_unique<ArithUnsignedWhenEquivalentPass>();
179 }
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v)
Succeeds when a value is statically non-negative in that it has a lower bound on its value (if it is treated as signed) and that bound is non-negative.
static CmpIPredicate toUnsignedPred(CmpIPredicate pred)
Return the unsigned equivalent of a signed comparison predicate, or the predicate itself if there is none.
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op)
Succeeds when the comparison predicate is a signed operation and all the operands are non-negative, indicating that the cmpi operation `op` can have its predicate changed to an unsigned equivalent.
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
const StateT * lookupState(AnchorT anchor) const
Lookup an analysis state for the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Replace signed ops with unsigned ones where they are proven equivalent.
std::unique_ptr< Pass > createArithUnsignedWhenEquivalentPass()
Create a pass to replace signed ops with unsigned ones where they are proven equivalent.
Include the generated interface declarations.
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358