1 //===- ExpandOps.cpp - Pass to legalize Arith ops for LLVM lowering --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/IR/TypeUtilities.h"
17 namespace mlir {
18 namespace arith {
20 #include "mlir/Dialect/Arith/Transforms/"
21 } // namespace arith
22 } // namespace mlir
24 using namespace mlir;
26 /// Create an integer or index constant.
27 static Value createConst(Location loc, Type type, int value,
28  PatternRewriter &rewriter) {
29  auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value);
30  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
31  return rewriter.create<arith::ConstantOp>(
32  loc, DenseElementsAttr::get(shapedTy, attr));
33  }
35  return rewriter.create<arith::ConstantOp>(loc, attr);
36 }
38 namespace {
40 /// Expands CeilDivUIOp (n, m) into
41 /// n == 0 ? 0 : ((n-1) / m) + 1
42 struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
44  LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
45  PatternRewriter &rewriter) const final {
46  Location loc = op.getLoc();
47  Value a = op.getLhs();
48  Value b = op.getRhs();
49  Value zero = createConst(loc, a.getType(), 0, rewriter);
50  Value compare =
51  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
52  Value one = createConst(loc, a.getType(), 1, rewriter);
53  Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
54  Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
55  Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
56  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne);
57  return success();
58  }
59 };
61 /// Expands CeilDivSIOp (n, m) into
62 /// 1) x = (m > 0) ? -1 : 1
63 /// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
64 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
66  LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
67  PatternRewriter &rewriter) const final {
68  Location loc = op.getLoc();
69  Type type = op.getType();
70  Value a = op.getLhs();
71  Value b = op.getRhs();
72  Value plusOne = createConst(loc, type, 1, rewriter);
73  Value zero = createConst(loc, type, 0, rewriter);
74  Value minusOne = createConst(loc, type, -1, rewriter);
75  // Compute x = (b>0) ? -1 : 1.
76  Value compare =
77  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
78  Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
79  // Compute positive res: 1 + ((x+a)/b).
80  Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
81  Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
82  Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
83  // Compute negative res: - ((-a)/b).
84  Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
85  Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
86  Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
87  // Result is (a*b>0) ? pos result : neg result.
88  // Note, we want to avoid using a*b because of possible overflow.
89  // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
90  // not particuliarly care if a*b<0 is true or false when b is zero
91  // as this will result in an illegal divide. So `a*b<0` can be reformulated
92  // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
93  // We pick the first expression here.
94  Value aNeg =
95  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
96  Value aPos =
97  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
98  Value bNeg =
99  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
100  Value bPos =
101  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
102  Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
103  Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
104  Value compareRes =
105  rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
106  // Perform substitution and return success.
107  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
108  negRes);
109  return success();
110  }
111 };
113 /// Expands FloorDivSIOp (x, y) into
114 /// z = x / y
115 /// if (z * y != x && (x < 0) != (y < 0)) {
116 /// return z - 1;
117 /// } else {
118 /// return z;
119 /// }
120 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
122  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
123  PatternRewriter &rewriter) const final {
124  Location loc = op.getLoc();
125  Type type = op.getType();
126  Value a = op.getLhs();
127  Value b = op.getRhs();
129  Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
130  Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
131  Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
132  loc, arith::CmpIPredicate::ne, a, product);
133  Value zero = createConst(loc, type, 0, rewriter);
135  Value aNeg =
136  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
137  Value bNeg =
138  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
140  Value signOpposite = rewriter.create<arith::CmpIOp>(
141  loc, arith::CmpIPredicate::ne, aNeg, bNeg);
142  Value cond =
143  rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
145  Value minusOne = createConst(loc, type, -1, rewriter);
146  Value quotientMinusOne =
147  rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
149  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
150  quotient);
151  return success();
152  }
153 };
155 template <typename OpTy, arith::CmpFPredicate pred>
156 struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
157 public:
160  LogicalResult matchAndRewrite(OpTy op,
161  PatternRewriter &rewriter) const final {
162  Value lhs = op.getLhs();
163  Value rhs = op.getRhs();
165  Location loc = op.getLoc();
166  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
167  static_assert(pred == arith::CmpFPredicate::UGT ||
168  pred == arith::CmpFPredicate::ULT,
169  "pred must be either UGT or ULT");
170  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
171  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
173  // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'.
174  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
175  rhs, rhs);
176  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
177  return success();
178  }
179 };
181 template <typename OpTy, arith::CmpFPredicate pred>
182 struct MaxNumMinNumFOpConverter : public OpRewritePattern<OpTy> {
183 public:
186  LogicalResult matchAndRewrite(OpTy op,
187  PatternRewriter &rewriter) const final {
188  Value lhs = op.getLhs();
189  Value rhs = op.getRhs();
191  Location loc = op.getLoc();
192  // If any operand is NaN, 'cmp' will be true (and 'select' returns 'lhs').
193  static_assert(pred == arith::CmpFPredicate::UGT ||
194  pred == arith::CmpFPredicate::ULT,
195  "pred must be either UGT or ULT");
196  Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs);
197  Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
199  // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'.
200  Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
201  lhs, lhs);
202  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select);
203  return success();
204  }
205 };
207 struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
209  LogicalResult matchAndRewrite(arith::ExtFOp op,
210  PatternRewriter &rewriter) const final {
211  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
212  auto operand = op.getOperand();
213  Type operandTy = operand.getType();
214  Type resultTy = op.getType();
215  Type operandETy = getElementTypeOrSelf(operandTy);
216  Type resultETy = getElementTypeOrSelf(resultTy);
218  if (!operandETy.isBF16() || !resultETy.isF32()) {
219  return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
220  }
222  Type i16Ty = b.getI16Type();
223  Type i32Ty = b.getI32Type();
224  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
225  i16Ty = shapedTy.clone(i16Ty);
226  i32Ty = shapedTy.clone(i32Ty);
227  }
229  Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
230  Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
232  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
233  Value shl = b.create<arith::ShLIOp>(exti, c16);
234  Value result = b.create<arith::BitcastOp>(resultTy, shl);
236  rewriter.replaceOp(op, result);
237  return success();
238  }
239 };
241 struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
243  LogicalResult matchAndRewrite(arith::TruncFOp op,
244  PatternRewriter &rewriter) const final {
245  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
246  auto operand = op.getOperand();
247  Type operandTy = operand.getType();
248  Type resultTy = op.getType();
249  Type operandETy = getElementTypeOrSelf(operandTy);
250  Type resultETy = getElementTypeOrSelf(resultTy);
252  if (!operandETy.isF32() || !resultETy.isBF16()) {
253  return rewriter.notifyMatchFailure(op, "not a trunc of f32 to bf16.");
254  }
256  if (op.getRoundingmodeAttr()) {
257  return rewriter.notifyMatchFailure(
258  op, "only applicable to default rounding mode.");
259  }
261  Type i16Ty = b.getI16Type();
262  Type i32Ty = b.getI32Type();
263  Type f32Ty = b.getF32Type();
264  if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
265  i16Ty = shapedTy.clone(i16Ty);
266  i32Ty = shapedTy.clone(i32Ty);
267  f32Ty = shapedTy.clone(f32Ty);
268  }
270  // Algorithm borrowed from this excellent code:
271  //
272  // There is a magic idea there, to let the addition of the rounding_bias to
273  // the mantissa simply overflow into the exponent bits. It's a bit of an
274  // aggressive, obfuscating optimization, but it is well-tested code, and it
275  // results in more concise and efficient IR.
276  // The case of NaN is handled separately (see isNaN and the final select).
277  // The case of infinities is NOT handled separately, which deserves an
278  // explanation. As the encoding of infinities has zero mantissa, the
279  // rounding-bias addition never carries into the exponent so that just gets
280  // truncated away, and as bfloat16 and float32 have the same number of
281  // exponent bits, that simple truncation is the desired outcome for
282  // infinities.
283  Value isNan =
284  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand);
285  // Constant used to make the rounding bias.
286  Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
287  // Constant used to generate a quiet NaN.
288  Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
289  // Small constants used to address bits.
290  Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
291  Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
292  // Reinterpret the input f32 value as bits.
293  Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand);
294  // Read bit 16 as a value in {0,1}.
295  Value bit16 =
296  b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1);
297  // Determine the rounding bias to add as either 0x7fff or 0x8000 depending
298  // on bit 16, implementing the tie-breaking "to nearest even".
299  Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF);
300  // Add the rounding bias. Generally we want this to be added to the
301  // mantissa, but nothing prevents this to from carrying into the exponent
302  // bits, which would feel like a bug, but this is the magic trick here:
303  // when that happens, the mantissa gets reset to zero and the exponent
304  // gets incremented by the carry... which is actually exactly what we
305  // want.
306  Value biased = b.create<arith::AddIOp>(bitcast, roundingBias);
307  // Now that the rounding-bias has been added, truncating the low bits
308  // yields the correctly rounded result.
309  Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
310  Value normalCaseResult_i16 =
311  b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
312  // Select either the above-computed result, or a quiet NaN constant
313  // if the input was NaN.
314  Value select =
315  b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
316  Value result = b.create<arith::BitcastOp>(resultTy, select);
317  rewriter.replaceOp(op, result);
318  return success();
319  }
320 };
322 struct ArithExpandOpsPass
323  : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
324  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
326  void runOnOperation() override {
327  RewritePatternSet patterns(&getContext());
328  ConversionTarget target(getContext());
332  target.addLegalDialect<arith::ArithDialect>();
333  // clang-format off
334  target.addIllegalOp<
335  arith::CeilDivSIOp,
336  arith::CeilDivUIOp,
337  arith::FloorDivSIOp,
338  arith::MaximumFOp,
339  arith::MinimumFOp,
340  arith::MaxNumFOp,
341  arith::MinNumFOp
342  >();
344  if (includeBf16) {
346  target.addDynamicallyLegalOp<arith::ExtFOp>(
347  [](arith::ExtFOp op) {
349  Type outETy = getElementTypeOrSelf(op.getType());
350  return !(inETy.isBF16() && outETy.isF32());
351  });
353  target.addDynamicallyLegalOp<arith::TruncFOp>(
354  [](arith::TruncFOp op) {
356  Type outETy = getElementTypeOrSelf(op.getType());
357  return !(inETy.isF32() && outETy.isBF16());
358  });
359  }
361  // clang-format on
362  if (failed(applyPartialConversion(getOperation(), target,
363  std::move(patterns))))
364  signalPassFailure();
365  }
366 };
368 } // namespace
371  RewritePatternSet &patterns) {
372  patterns
373  .add<CeilDivSIOpConverter, CeilDivUIOpConverter, FloorDivSIOpConverter>(
374  patterns.getContext());
375 }
378  patterns.add<BFloat16ExtFOpConverter, BFloat16TruncFOpConverter>(
379  patterns.getContext());
380 }
384  // clang-format off
385  patterns.add<
386  MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
387  MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
388  MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
389  MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
390  >(patterns.getContext());
391  // clang-format on
392 }
