MLIR  21.0.0git
UnsignedWhenEquivalent.cpp
Go to the documentation of this file.
1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2 // unsigned
3 // ones when all their arguments and results are statically non-negative --===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
12 
16 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
20 namespace arith {
21 #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
22 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
23 } // namespace arith
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::arith;
28 using namespace mlir::dataflow;
29 
30 /// Succeeds when the comparison predicate is a signed operation and all the
31 /// operands are non-negative, indicating that the cmpi operation `op` can have
32 /// its predicate changed to an unsigned equivalent.
33 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
34  CmpIPredicate pred = op.getPredicate();
35  switch (pred) {
36  case CmpIPredicate::sle:
37  case CmpIPredicate::slt:
38  case CmpIPredicate::sge:
39  case CmpIPredicate::sgt:
40  return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
41  return succeeded(staticallyNonNegative(solver, v));
42  }));
43  default:
44  return failure();
45  }
46 }
47 
48 /// Return the unsigned equivalent of a signed comparison predicate,
49 /// or the predicate itself if there is none.
50 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
51  switch (pred) {
52  case CmpIPredicate::sle:
53  return CmpIPredicate::ule;
54  case CmpIPredicate::slt:
55  return CmpIPredicate::ult;
56  case CmpIPredicate::sge:
57  return CmpIPredicate::uge;
58  case CmpIPredicate::sgt:
59  return CmpIPredicate::ugt;
60  default:
61  return pred;
62  }
63 }
64 
65 namespace {
66 class DataFlowListener : public RewriterBase::Listener {
67 public:
68  DataFlowListener(DataFlowSolver &s) : s(s) {}
69 
70 protected:
71  void notifyOperationErased(Operation *op) override {
72  s.eraseState(s.getProgramPointAfter(op));
73  for (Value res : op->getResults())
74  s.eraseState(res);
75  }
76 
77  DataFlowSolver &s;
78 };
79 
80 // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
81 // (via staticallyNonNegative) relies on this. These transformations may not be
82 // valid for 32bit index, need more investigation.
83 
84 template <typename Signed, typename Unsigned>
85 struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
86  ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
87  : OpRewritePattern<Signed>(context), solver(s) {}
88 
89  LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
90  if (failed(
91  staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
92  return failure();
93 
94  rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
95  op->getAttrs());
96  return success();
97  }
98 
99 private:
100  DataFlowSolver &solver;
101 };
102 
103 struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
104  ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
105  : OpRewritePattern<CmpIOp>(context), solver(s) {}
106 
107  LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
108  if (failed(isCmpIConvertable(this->solver, op)))
109  return failure();
110 
111  rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
112  op.getLhs(), op.getRhs());
113  return success();
114  }
115 
116 private:
117  DataFlowSolver &solver;
118 };
119 
120 struct ArithUnsignedWhenEquivalentPass
121  : public arith::impl::ArithUnsignedWhenEquivalentBase<
122  ArithUnsignedWhenEquivalentPass> {
123 
124  void runOnOperation() override {
125  Operation *op = getOperation();
126  MLIRContext *ctx = op->getContext();
127  DataFlowSolver solver;
128  solver.load<DeadCodeAnalysis>();
129  solver.load<IntegerRangeAnalysis>();
130  if (failed(solver.initializeAndRun(op)))
131  return signalPassFailure();
132 
133  DataFlowListener listener(solver);
134 
137 
138  walkAndApplyPatterns(op, std::move(patterns), &listener);
139  }
140 };
141 } // end anonymous namespace
142 
145  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
146  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
147  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
148  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
149  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
150  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
151  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
152  patterns.getContext(), solver);
153 }
154 
156  return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157 }
static CmpIPredicate toUnsignedPred(CmpIPredicate pred)
Return the unsigned equivalent of a signed comparison predicate, or the predicate itself if there is ...
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op)
Succeeds when the comparison predicate is a signed operation and all the operands are non-negative,...
The general data-flow analysis solver.
void eraseState(AnchorT anchor)
Erase any analysis state associated with the given lattice anchor.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
void populateUnsignedWhenEquivalentPatterns(RewritePatternSet &patterns, DataFlowSolver &solver)
Replace signed ops with unsigned ones where they are proven equivalent.
std::unique_ptr< Pass > createArithUnsignedWhenEquivalentPass()
Create a pass to replace signed ops with unsigned ones where they are proven equivalent.
LogicalResult staticallyNonNegative(DataFlowSolver &solver, Operation *op)
Succeeds if an op can be converted to its unsigned equivalent without changing its semantics.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314