MLIR  19.0.0git
UnsignedWhenEquivalent.cpp
Go to the documentation of this file.
1 //===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2 // unsigned
3 // ones when all their arguments and results are statically non-negative --===//
4 //
5 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6 // See https://llvm.org/LICENSE.txt for license information.
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //
9 //===----------------------------------------------------------------------===//
10 
12 
17 
18 namespace mlir {
19 namespace arith {
20 #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
21 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
22 } // namespace arith
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::arith;
27 using namespace mlir::dataflow;
28 
29 /// Succeeds when a value is statically non-negative in that it has a lower
30 /// bound on its value (if it is treated as signed) and that bound is
31 /// non-negative.
33  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
34  if (!result || result->getValue().isUninitialized())
35  return failure();
36  const ConstantIntRanges &range = result->getValue().getValue();
37  return success(range.smin().isNonNegative());
38 }
39 
40 /// Succeeds if an op can be converted to its unsigned equivalent without
41 /// changing its semantics. This is the case when none of its openands or
42 /// results can be below 0 when analyzed from a signed perspective.
44  Operation *op) {
45  auto nonNegativePred = [&solver](Value v) -> bool {
46  return succeeded(staticallyNonNegative(solver, v));
47  };
48  return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
49  llvm::all_of(op->getResults(), nonNegativePred));
50 }
51 
52 /// Succeeds when the comparison predicate is a signed operation and all the
53 /// operands are non-negative, indicating that the cmpi operation `op` can have
54 /// its predicate changed to an unsigned equivalent.
55 static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
56  CmpIPredicate pred = op.getPredicate();
57  switch (pred) {
58  case CmpIPredicate::sle:
59  case CmpIPredicate::slt:
60  case CmpIPredicate::sge:
61  case CmpIPredicate::sgt:
62  return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
63  return succeeded(staticallyNonNegative(solver, v));
64  }));
65  default:
66  return failure();
67  }
68 }
69 
70 /// Return the unsigned equivalent of a signed comparison predicate,
71 /// or the predicate itself if there is none.
72 static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
73  switch (pred) {
74  case CmpIPredicate::sle:
75  return CmpIPredicate::ule;
76  case CmpIPredicate::slt:
77  return CmpIPredicate::ult;
78  case CmpIPredicate::sge:
79  return CmpIPredicate::uge;
80  case CmpIPredicate::sgt:
81  return CmpIPredicate::ugt;
82  default:
83  return pred;
84  }
85 }
86 
87 namespace {
88 template <typename Signed, typename Unsigned>
89 struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
91 
92  LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
93  ConversionPatternRewriter &rw) const override {
95  adaptor.getOperands(), op->getAttrs());
96  return success();
97  }
98 };
99 
100 struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
102 
103  LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
104  ConversionPatternRewriter &rw) const override {
105  rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
106  op.getLhs(), op.getRhs());
107  return success();
108  }
109 };
110 
111 struct ArithUnsignedWhenEquivalentPass
112  : public arith::impl::ArithUnsignedWhenEquivalentBase<
113  ArithUnsignedWhenEquivalentPass> {
114  /// Implementation structure: first find all equivalent ops and collect them,
115  /// then perform all the rewrites in a second pass over the target op. This
116  /// ensures that analysis results are not invalidated during rewriting.
117  void runOnOperation() override {
118  Operation *op = getOperation();
119  MLIRContext *ctx = op->getContext();
120  DataFlowSolver solver;
121  solver.load<DeadCodeAnalysis>();
122  solver.load<IntegerRangeAnalysis>();
123  if (failed(solver.initializeAndRun(op)))
124  return signalPassFailure();
125 
126  ConversionTarget target(*ctx);
127  target.addLegalDialect<ArithDialect>();
128  target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129  MinSIOp, MaxSIOp, ExtSIOp>(
130  [&solver](Operation *op) -> std::optional<bool> {
131  return failed(staticallyNonNegative(solver, op));
132  });
133  target.addDynamicallyLegalOp<CmpIOp>(
134  [&solver](CmpIOp op) -> std::optional<bool> {
135  return failed(isCmpIConvertable(solver, op));
136  });
137 
138  RewritePatternSet patterns(ctx);
139  patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140  ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141  ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142  ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143  ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144  ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145  ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146  ctx);
147 
148  if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
149  signalPassFailure();
150  }
151  }
152 };
153 } // end anonymous namespace
154 
156  return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157 }
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v)
Succeeds when a value is statically non-negative in that it has a lower bound on its value (if it is ...
static CmpIPredicate toUnsignedPred(CmpIPredicate pred)
Return the unsigned equivalent of a signed comparison predicate, or the predicate itself if there is ...
static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op)
Succeeds when the comparison predicate is a signed operation and all the operands are non-negative,...
A set of arbitrary-precision integers representing bounds on a given integer value.
const APInt & smin() const
The minimum value of an integer when it is interpreted as signed.
This class implements a pattern rewriter for use with ConversionPatterns.
This class describes a specific conversion target.
The general data-flow analysis solver.
const StateT * lookupState(PointT point) const
Lookup an analysis state for the given program point.
AnalysisT * load(Args &&...args)
Load an analysis into the solver. Return the analysis instance.
LogicalResult initializeAndRun(Operation *top)
Initialize the children analyses starting from the provided top-level operation and run the analysis ...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Dead code analysis analyzes control-flow, as understood by RegionBranchOpInterface and BranchOpInterf...
Integer range analysis determines the integer value range of SSA values using operations that define ...
This lattice element represents the integer value range of an SSA value.
std::unique_ptr< Pass > createArithUnsignedWhenEquivalentPass()
Create a pass to replace signed ops with unsigned ones where they are proven equivalent.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26