MLIR 23.0.0git
UnsignedWhenEquivalent.cpp
Go to the documentation of this file.
//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
// unsigned ones when all their arguments and results are statically
// non-negative ---------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
10
13
19
20namespace mlir {
21namespace arith {
22#define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENTPASS
23#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
24} // namespace arith
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::arith;
29using namespace mlir::dataflow;
30
31/// Succeeds when the comparison predicate is a signed operation and all the
32/// operands are non-negative, indicating that the cmpi operation `op` can have
33/// its predicate changed to an unsigned equivalent.
34static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
35 CmpIPredicate pred = op.getPredicate();
36 switch (pred) {
37 case CmpIPredicate::sle:
38 case CmpIPredicate::slt:
39 case CmpIPredicate::sge:
40 case CmpIPredicate::sgt:
41 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
42 return succeeded(staticallyNonNegative(solver, v));
43 }));
44 default:
45 return failure();
46 }
47}
48
49/// Return the unsigned equivalent of a signed comparison predicate,
50/// or the predicate itself if there is none.
51static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
52 switch (pred) {
53 case CmpIPredicate::sle:
54 return CmpIPredicate::ule;
55 case CmpIPredicate::slt:
56 return CmpIPredicate::ult;
57 case CmpIPredicate::sge:
58 return CmpIPredicate::uge;
59 case CmpIPredicate::sgt:
60 return CmpIPredicate::ugt;
61 default:
62 return pred;
63 }
64}
65
66namespace {
67class DataFlowListener : public RewriterBase::Listener {
68public:
69 DataFlowListener(DataFlowSolver &s) : s(s) {}
70
71protected:
72 void notifyOperationErased(Operation *op) override {
73 s.eraseState(s.getProgramPointAfter(op));
74 for (Value res : op->getResults())
75 s.eraseState(res);
76 }
77
78 DataFlowSolver &s;
79};
80
81// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
82// (via staticallyNonNegative) relies on this. These transformations may not be
83// valid for 32bit index, need more investigation.
84
85template <typename Signed, typename Unsigned>
86struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
87 ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
88 : OpRewritePattern<Signed>(context), solver(s) {}
89
90 LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
91 if (failed(
92 staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
93 return failure();
94
95 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
96 op->getAttrs());
97 return success();
98 }
99
100private:
101 DataFlowSolver &solver;
102};
103
104struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
105 ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
106 : OpRewritePattern<CmpIOp>(context), solver(s) {}
107
108 LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
109 if (failed(isCmpIConvertable(this->solver, op)))
110 return failure();
111
112 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
113 op.getLhs(), op.getRhs());
114 return success();
115 }
116
117private:
118 DataFlowSolver &solver;
119};
120
121struct ArithUnsignedWhenEquivalentPass
122 : public arith::impl::ArithUnsignedWhenEquivalentPassBase<
123 ArithUnsignedWhenEquivalentPass> {
124
125 void runOnOperation() override {
126 Operation *op = getOperation();
127 MLIRContext *ctx = op->getContext();
128 DataFlowSolver solver;
129 solver.load<SparseConstantPropagation>();
130 solver.load<DeadCodeAnalysis>();
131 solver.load<IntegerRangeAnalysis>();
132 if (failed(solver.initializeAndRun(op)))
133 return signalPassFailure();
134
135 DataFlowListener listener(solver);
136
137 RewritePatternSet patterns(ctx);
139
140 walkAndApplyPatterns(op, std::move(patterns), &listener);
141 }
142};
143} // end anonymous namespace
144
146 RewritePatternSet &patterns, DataFlowSolver &solver) {
147 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
148 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
149 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
150 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
151 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
152 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
153 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
154 patterns.getContext(), solver);
155}
return success()
static CmpIPredicate toUnsignedPred(CmpIPredicate pred)
Return the unsigned equivalent of a signed comparison predicate, or the predicate itself if there is ...
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op)
Succeeds when the comparison predicate is a signed operation and all the operands are non-negative,...
The general data-flow analysis solver.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
result_range getResults()
Definition Operation.h:441
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Replace signed ops with unsigned ones where they are proven equivalent.
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op)
Succeeds if an op can be converted to its unsigned equivalent without changing its semantics.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...