MLIR  17.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "mlir/IR/TypeUtilities.h"
14 
15 namespace mlir {
16 namespace arith {
17 #define GEN_PASS_DEF_ARITHEXPANDOPS
18 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
19 } // namespace arith
20 } // namespace mlir
21 
22 using namespace mlir;
23 
24 /// Create an integer or index constant.
25 static Value createConst(Location loc, Type type, int value,
26  PatternRewriter &rewriter) {
27  return rewriter.create<arith::ConstantOp>(
28  loc, rewriter.getIntegerAttr(type, value));
29 }
30 
31 namespace {
32 
33 /// Expands CeilDivUIOp (n, m) into
34 /// n == 0 ? 0 : ((n-1) / m) + 1
35 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
37  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
38  PatternRewriter &rewriter) const final {
39  Location loc = op.getLoc();
40  Value a = op.getLhs();
41  Value b = op.getRhs();
42  Value zero = createConst(loc, a.getType(), 0, rewriter);
43  Value compare =
44  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
45  Value one = createConst(loc, a.getType(), 1, rewriter);
46  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
47  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
48  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
49  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
50  return success();
51  }
52 };
53 
54 /// Expands CeilDivSIOp (n, m) into
55 /// 1) x = (m > 0) ? -1 : 1
56 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
57 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
59  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
60  PatternRewriter &rewriter) const final {
61  Location loc = op.getLoc();
62  Type type = op.getType();
63  Value a = op.getLhs();
64  Value b = op.getRhs();
65  Value plusOne = createConst(loc, type, 1, rewriter);
66  Value zero = createConst(loc, type, 0, rewriter);
67  Value minusOne = createConst(loc, type, -1, rewriter);
68  // Compute x = (b>0) ? -1 : 1.
69  Value compare =
70  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
71  Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
72  // Compute positive res: 1 + ((x+a)/b).
73  Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
74  Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
75  Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
76  // Compute negative res: - ((-a)/b).
77  Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
78  Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
79  Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
80  // Result is (a*b>0) ? pos result : neg result.
81  // Note, we want to avoid using a*b because of possible overflow.
82  // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
83  // not particuliarly care if a*b<0 is true or false when b is zero
84  // as this will result in an illegal divide. So `a*b<0` can be reformulated
85  // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
86  // We pick the first expression here.
87  Value aNeg =
88  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
89  Value aPos =
90  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
91  Value bNeg =
92  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
93  Value bPos =
94  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
95  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
96  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
97  Value compareRes =
98  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
99  // Perform substitution and return success.
100  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
101  negRes);
102  return success();
103  }
104 };
105 
106 /// Expands FloorDivSIOp (n, m) into
107 /// 1) x = (m<0) ? 1 : -1
108 /// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
109 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
111  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
112  PatternRewriter &rewriter) const final {
113  Location loc = op.getLoc();
114  Type type = op.getType();
115  Value a = op.getLhs();
116  Value b = op.getRhs();
117  Value plusOne = createConst(loc, type, 1, rewriter);
118  Value zero = createConst(loc, type, 0, rewriter);
119  Value minusOne = createConst(loc, type, -1, rewriter);
120  // Compute x = (b<0) ? 1 : -1.
121  Value compare =
122  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
123  Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
124  // Compute negative res: -1 - ((x-a)/b).
125  Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
126  Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
127  Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
128  // Compute positive res: a/b.
129  Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
130  // Result is (a*b<0) ? negative result : positive result.
131  // Note, we want to avoid using a*b because of possible overflow.
132  // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
133  // not particuliarly care if a*b<0 is true or false when b is zero
134  // as this will result in an illegal divide. So `a*b<0` can be reformulated
135  // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
136  // We pick the first expression here.
137  Value aNeg =
138  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
139  Value aPos =
140  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
141  Value bNeg =
142  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
143  Value bPos =
144  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
145  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
146  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
147  Value compareRes =
148  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
149  // Perform substitution and return success.
150  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
151  posRes);
152  return success();
153  }
154 };
155 
156 template <typename OpTy, arith::CmpFPredicate pred>
157 struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
158 public:
160 
161  LogicalResult matchAndRewrite(OpTy op,
162  PatternRewriter &rewriter) const final {
163  Value lhs = op.getLhs();
164  Value rhs = op.getRhs();
165 
166  Location loc = op.getLoc();
167  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
168  static_assert(pred == arith::CmpFPredicate::UGT ||
169  pred == arith::CmpFPredicate::ULT,
170  "pred must be either UGT or ULT");
171  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
172  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
173 
174  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
175  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
176  rhs, rhs);
177  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
178  return success();
179  }
180 };
181 
182 struct ArithExpandOpsPass
183  : public arith::impl::ArithExpandOpsBase<ArithExpandOpsPass> {
184  void runOnOperation() override {
185  RewritePatternSet patterns(&getContext());
186  ConversionTarget target(getContext());
187 
189 
190  target.addLegalDialect<arith::ArithDialect>();
191  // clang-format off
192  target.addIllegalOp<
193  arith::CeilDivSIOp,
194  arith::CeilDivUIOp,
195  arith::FloorDivSIOp,
196  arith::MaxFOp,
197  arith::MinFOp
198  >();
199  // clang-format on
200  if (failed(applyPartialConversion(getOperation(), target,
201  std::move(patterns))))
202  signalPassFailure();
203  }
204 };
205 
206 } // namespace
207 
209  RewritePatternSet &patterns) {
210  patterns
211  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
212  patterns.getContext());
213 }
214 
217  // clang-format off
218  patterns.add<
219  MaxMinFOpConverter<MaxFOp, arith::CmpFPredicate::UGT>,
220  MaxMinFOpConverter<MinFOp, arith::CmpFPredicate::ULT>
221  >(patterns.getContext());
222  // clang-format on
223 }
224 
225 std::unique_ptr<Pass> mlir::arith::createArithExpandOpsPass() {
226  return std::make_unique<ArithExpandOpsPass>();
227 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:212
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ceil/floor division ops.
Definition: ExpandOps.cpp:208
std::unique_ptr< Pass > createArithExpandOpsPass()
Create a pass to legalize Arith ops.
Definition: ExpandOps.cpp:225
void populateArithExpandOpsPatterns(RewritePatternSet &patterns)
Add patterns to expand Arith ops.
Definition: ExpandOps.cpp:215
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
Definition: Fraction.h:59
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361