MLIR  22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Code to expand various math operations. ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of various math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 namespace mlir::math {
24 #define GEN_PASS_DEF_MATHEXPANDOPSPASS
25 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26 } // namespace mlir::math
27 
28 /// Create a float constant.
29 static Value createFloatConst(Location loc, Type type, APFloat value,
30  OpBuilder &b) {
31  bool losesInfo = false;
32  auto eltType = getElementTypeOrSelf(type);
33  // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
34  value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35  APFloat::rmNearestTiesToEven, &losesInfo);
36  auto attr = b.getFloatAttr(eltType, value);
37  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
38  return arith::ConstantOp::create(b, loc,
39  DenseElementsAttr::get(shapedTy, attr));
40  }
41 
42  return arith::ConstantOp::create(b, loc, attr);
43 }
44 
45 static Value createFloatConst(Location loc, Type type, double value,
46  OpBuilder &b) {
47  return createFloatConst(loc, type, APFloat(value), b);
48 }
49 
50 /// Create an integer constant.
51 static Value createIntConst(Location loc, Type type, int64_t value,
52  OpBuilder &b) {
53  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
54  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
55  return arith::ConstantOp::create(b, loc,
56  DenseElementsAttr::get(shapedTy, attr));
57  }
58 
59  return arith::ConstantOp::create(b, loc, attr);
60 }
61 
63  Type opType = operand.getType();
64  Type i64Ty = b.getI64Type();
65  if (auto shapedTy = dyn_cast<ShapedType>(opType))
66  i64Ty = shapedTy.clone(i64Ty);
67  Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
68  Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
69  // The truncation does not preserve the sign when the truncated
70  // value is -0. So here the sign is copied again.
71  return math::CopySignOp::create(b, fpFixedConvert, operand);
72 }
73 
74 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
75 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
76  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
77  Value operand = op.getOperand();
78  Type opType = operand.getType();
79 
80  Value exp = math::ExpOp::create(b, operand);
81  Value neg = arith::NegFOp::create(b, operand);
82  Value nexp = math::ExpOp::create(b, neg);
83  Value sub = arith::SubFOp::create(b, exp, nexp);
84  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
85  Value res = arith::MulFOp::create(b, sub, half);
86  rewriter.replaceOp(op, res);
87  return success();
88 }
89 
90 // coshf(float x) -> (exp(x) + exp(-x)) / 2
91 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
92  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
93  Value operand = op.getOperand();
94  Type opType = operand.getType();
95 
96  Value exp = math::ExpOp::create(b, operand);
97  Value neg = arith::NegFOp::create(b, operand);
98  Value nexp = math::ExpOp::create(b, neg);
99  Value add = arith::AddFOp::create(b, exp, nexp);
100  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
101  Value res = arith::MulFOp::create(b, add, half);
102  rewriter.replaceOp(op, res);
103  return success();
104 }
105 
106 /// Expands tanh op into
107 /// 1-exp^{-2x} / 1+exp^{-2x}
108 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
109 /// We compute a "signs" value which is -1 if input is negative and +1 if input
110 /// is positive. Then multiply the input by this value, guaranteeing that the
111 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
112 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
113 /// result by `sign(x)` to retain sign of the real result.
114 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
115  auto floatType = op.getOperand().getType();
116  Location loc = op.getLoc();
117  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
119  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
120 
121  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
122  Value isNegative = arith::CmpFOp::create(
123  rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124  Value isNegativeFloat =
125  arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126  Value isNegativeTimesNegTwo =
127  arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128  Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
129 
130  // Normalize input to positive value: y = sign(x) * x
131  Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
132 
133  // Decompose on normalized input
134  Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135  Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136  Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137  Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138  Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
139 
140  // Multiply result by sign(x) to retain signs from negative inputs
141  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
142 
143  return success();
144 }
145 
146 // Converts math.tan to math.sin, math.cos, and arith.divf.
147 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
148  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
149  Value operand = op.getOperand();
150  Type type = operand.getType();
151  Value sin = math::SinOp::create(b, type, operand);
152  Value cos = math::CosOp::create(b, type, operand);
153  Value div = arith::DivFOp::create(b, type, sin, cos);
154  rewriter.replaceOp(op, div);
155  return success();
156 }
157 
158 // asinh(float x) -> log(x + sqrt(x**2 + 1))
159 static LogicalResult convertAsinhOp(math::AsinhOp op,
160  PatternRewriter &rewriter) {
161  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
162  Value operand = op.getOperand();
163  Type opType = operand.getType();
164 
165  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
166  Value fma = math::FmaOp::create(b, operand, operand, one);
167  Value sqrt = math::SqrtOp::create(b, fma);
168  Value add = arith::AddFOp::create(b, operand, sqrt);
169  Value res = math::LogOp::create(b, add);
170  rewriter.replaceOp(op, res);
171  return success();
172 }
173 
174 // acosh(float x) -> log(x + sqrt(x**2 - 1))
175 static LogicalResult convertAcoshOp(math::AcoshOp op,
176  PatternRewriter &rewriter) {
177  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
178  Value operand = op.getOperand();
179  Type opType = operand.getType();
180 
181  Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
182  Value fma = math::FmaOp::create(b, operand, operand, negOne);
183  Value sqrt = math::SqrtOp::create(b, fma);
184  Value add = arith::AddFOp::create(b, operand, sqrt);
185  Value res = math::LogOp::create(b, add);
186  rewriter.replaceOp(op, res);
187  return success();
188 }
189 
190 // atanh(float x) -> log((1 + x) / (1 - x)) / 2
191 static LogicalResult convertAtanhOp(math::AtanhOp op,
192  PatternRewriter &rewriter) {
193  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
194  Value operand = op.getOperand();
195  Type opType = operand.getType();
196 
197  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
198  Value add = arith::AddFOp::create(b, operand, one);
199  Value neg = arith::NegFOp::create(b, operand);
200  Value sub = arith::AddFOp::create(b, neg, one);
201  Value div = arith::DivFOp::create(b, add, sub);
202  Value log = math::LogOp::create(b, div);
203  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
204  Value res = arith::MulFOp::create(b, log, half);
205  rewriter.replaceOp(op, res);
206  return success();
207 }
208 
209 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
210  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211  Value operandA = op.getOperand(0);
212  Value operandB = op.getOperand(1);
213  Value operandC = op.getOperand(2);
214  Type type = op.getType();
215  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
216  Value add = arith::AddFOp::create(b, type, mult, operandC);
217  rewriter.replaceOp(op, add);
218  return success();
219 }
220 
221 // Converts a ceilf() function to the following:
222 // ceilf(float x) ->
223 // y = (float)(int) x
224 // if (x > y) then incr = 1 else incr = 0
225 // y = y + incr <= replace this op with the ceilf op.
226 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
227  // Creating constants assumes the static shaped type.
228  auto shapedType = dyn_cast<ShapedType>(op.getType());
229  if (shapedType && !shapedType.hasStaticShape())
230  return failure();
231 
232  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
233  Value operand = op.getOperand();
234  Type opType = operand.getType();
235  Value fpFixedConvert = createTruncatedFPValue(operand, b);
236 
237  // Creating constants for later use.
238  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
239  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
240 
241  Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
242  fpFixedConvert);
243  Value incrValue =
244  arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
245 
246  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
247  rewriter.replaceOp(op, ret);
248  return success();
249 }
250 
251 // Convert `math.fpowi` to a series of `arith.mulf` operations.
252 // If the power is negative, we divide one by the result.
253 // If both the base and power are zero, the result is 1.
254 // In the case of non constant power, we convert the operation to `math.powf`.
255 static LogicalResult convertFPowIOp(math::FPowIOp op,
256  PatternRewriter &rewriter) {
257  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
258  Value base = op.getOperand(0);
259  Value power = op.getOperand(1);
260  Type baseType = base.getType();
261 
262  auto convertFPowItoPowf = [&]() -> LogicalResult {
263  Value castPowerToFp =
264  arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
265  Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
266  castPowerToFp);
267  rewriter.replaceOp(op, res);
268  return success();
269  };
270 
271  Attribute cstAttr;
272  if (!matchPattern(power, m_Constant(&cstAttr)))
273  return convertFPowItoPowf();
274 
275  APInt value;
276  if (!matchPattern(cstAttr, m_ConstantInt(&value)))
277  return convertFPowItoPowf();
278 
279  int64_t powerInt = value.getSExtValue();
280  bool isNegative = powerInt < 0;
281  int64_t absPower = std::abs(powerInt);
282  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
283  Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
284 
285  while (absPower > 0) {
286  if (absPower & 1)
287  res = arith::MulFOp::create(b, baseType, base, res);
288  absPower >>= 1;
289  base = arith::MulFOp::create(b, baseType, base, base);
290  }
291 
292  // Make sure not to introduce UB in case of negative power.
293  if (isNegative) {
294  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
295  .getFloatSemantics();
296  Value zero =
297  createFloatConst(op->getLoc(), baseType,
298  APFloat::getZero(sem, /*Negative=*/false), rewriter);
299  Value negZero =
300  createFloatConst(op->getLoc(), baseType,
301  APFloat::getZero(sem, /*Negative=*/true), rewriter);
302  Value posInfinity =
303  createFloatConst(op->getLoc(), baseType,
304  APFloat::getInf(sem, /*Negative=*/false), rewriter);
305  Value negInfinity =
306  createFloatConst(op->getLoc(), baseType,
307  APFloat::getInf(sem, /*Negative=*/true), rewriter);
308  Value zeroEqCheck =
309  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
310  Value negZeroEqCheck =
311  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
312  res = arith::DivFOp::create(b, baseType, one, res);
313  res =
314  arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
315  res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
316  res);
317  }
318 
319  rewriter.replaceOp(op, res);
320  return success();
321 }
322 
323 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
324 // Some special cases where b is constant are handled separately:
325 // when b == 0, or |b| == 0.5, 1.0, or 2.0.
326 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
327  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
328  Value operandA = op.getOperand(0);
329  Value operandB = op.getOperand(1);
330  auto typeA = operandA.getType();
331  auto typeB = operandB.getType();
332 
333  auto &sem =
334  cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
335  APFloat valueB(sem);
336  auto mulf = [&](Value x, Value y) -> Value {
337  return arith::MulFOp::create(b, x, y);
338  };
339  if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
340  if (valueB.isZero()) {
341  // a^0 -> 1
342  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
343  rewriter.replaceOp(op, one);
344  return success();
345  }
346  if (valueB.isExactlyValue(1.0)) {
347  // a^1 -> a
348  rewriter.replaceOp(op, operandA);
349  return success();
350  }
351  if (valueB.isExactlyValue(-1.0)) {
352  // a^(-1) -> 1 / a
353  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
354  Value div = arith::DivFOp::create(b, one, operandA);
355  rewriter.replaceOp(op, div);
356  return success();
357  }
358  if (valueB.isExactlyValue(0.5)) {
359  // a^(1/2) -> sqrt(a)
360  Value sqrt = math::SqrtOp::create(b, operandA);
361  rewriter.replaceOp(op, sqrt);
362  return success();
363  }
364  if (valueB.isExactlyValue(-0.5)) {
365  // a^(-1/2) -> 1 / sqrt(a)
366  Value rsqrt = math::RsqrtOp::create(b, operandA);
367  rewriter.replaceOp(op, rsqrt);
368  return success();
369  }
370  if (valueB.isExactlyValue(2.0)) {
371  // a^2 -> a * a
372  rewriter.replaceOp(op, mulf(operandA, operandA));
373  return success();
374  }
375  if (valueB.isExactlyValue(-2.0)) {
376  // a^(-2) -> 1 / (a * a)
377  Value one =
378  createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
379  Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
380  rewriter.replaceOp(op, div);
381  return success();
382  }
383  if (valueB.isExactlyValue(3.0)) {
384  rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
385  return success();
386  }
387  }
388 
389  Value logA = math::LogOp::create(b, operandA);
390  Value mult = arith::MulFOp::create(b, operandB, logA);
391  Value expResult = math::ExpOp::create(b, mult);
392  rewriter.replaceOp(op, expResult);
393  return success();
394 }
395 
396 // exp2f(float x) -> exp(x * ln(2))
397 // Proof: Let's say 2^x = y
398 // ln(2^x) = ln(y)
399 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y
400 static LogicalResult convertExp2fOp(math::Exp2Op op,
401  PatternRewriter &rewriter) {
402  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
403  Value operand = op.getOperand();
404  Type opType = operand.getType();
405  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
406  Value mult = arith::MulFOp::create(b, opType, operand, ln2);
407  Value exp = math::ExpOp::create(b, op->getLoc(), mult);
408  rewriter.replaceOp(op, exp);
409  return success();
410 }
411 
412 static LogicalResult convertRoundOp(math::RoundOp op,
413  PatternRewriter &rewriter) {
414  Location loc = op.getLoc();
415  ImplicitLocOpBuilder b(loc, rewriter);
416  Value operand = op.getOperand();
417  Type opType = operand.getType();
418  Type opEType = getElementTypeOrSelf(opType);
419 
420  if (!opEType.isF32()) {
421  return rewriter.notifyMatchFailure(op, "not a round of f32.");
422  }
423 
424  Type i32Ty = b.getI32Type();
425  if (auto shapedTy = dyn_cast<ShapedType>(opType))
426  i32Ty = shapedTy.clone(i32Ty);
427 
428  Value half = createFloatConst(loc, opType, 0.5, b);
429  Value c23 = createIntConst(loc, i32Ty, 23, b);
430  Value c127 = createIntConst(loc, i32Ty, 127, b);
431  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
432 
433  Value incrValue = math::CopySignOp::create(b, half, operand);
434  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
435  Value fpFixedConvert = createTruncatedFPValue(add, b);
436 
437  // There are three cases where adding 0.5 to the value and truncating by
438  // converting to an i64 does not result in the correct behavior:
439  //
440  // 1. Special values: +-inf and +-nan
441  // Casting these special values to i64 has undefined behavior. To identify
442  // these values, we use the fact that these values are the only float
443  // values with the maximum possible biased exponent.
444  //
445  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
446  // Adding 0.5 to a float larger than or equal to 2^23 results in precision
447  // errors that sometimes round the value up and sometimes round the value
448  // down. For example:
449  // 8388608.0 + 0.5 = 8388608.0
450  // 8388609.0 + 0.5 = 8388610.0
451  //
452  // 3. Very large values: |x| > INT_64_MAX
453  // Casting to i64 a value greater than the max i64 value will overflow the
454  // i64 leading to wrong outputs.
455  //
456  // All three cases satisfy the property `biasedExp >= 23`.
457  Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
458  Value operandExp = arith::AndIOp::create(
459  b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
460  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
461  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
462  b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
463 
464  Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
465  fpFixedConvert);
466  rewriter.replaceOp(op, result);
467  return success();
468 }
469 
470 // Converts math.ctlz to scf and arith operations. This is done
471 // by performing a binary search on the bits.
472 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
473  PatternRewriter &rewriter) {
474  auto operand = op.getOperand();
475  auto operandTy = operand.getType();
476  auto eTy = getElementTypeOrSelf(operandTy);
477  Location loc = op.getLoc();
478 
479  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
480  if (bitwidth > 64)
481  return failure();
482 
483  uint64_t allbits = -1;
484  if (bitwidth < 64) {
485  allbits = allbits >> (64 - bitwidth);
486  }
487 
488  Value x = operand;
489  Value count = createIntConst(loc, operandTy, 0, rewriter);
490  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
491  auto half = bw / 2;
492  auto bits = createIntConst(loc, operandTy, half, rewriter);
493  auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
494 
495  Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
496  x, mask);
497  Value add = arith::AddIOp::create(rewriter, loc, count, bits);
498  Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
499 
500  x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
501  count = arith::SelectOp::create(rewriter, loc, pred, add, count);
502  }
503 
504  Value zero = createIntConst(loc, operandTy, 0, rewriter);
505  Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
506  operand, zero);
507 
508  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
509  Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
510  rewriter.replaceOp(op, sel);
511  return success();
512 }
513 
514 // Convert `math.roundeven` into `math.round` + arith ops
515 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
516  PatternRewriter &rewriter) {
517  Location loc = op.getLoc();
518  ImplicitLocOpBuilder b(loc, rewriter);
519  auto operand = op.getOperand();
520  Type operandTy = operand.getType();
521  Type resultTy = op.getType();
522  Type operandETy = getElementTypeOrSelf(operandTy);
523  Type resultETy = getElementTypeOrSelf(resultTy);
524 
525  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
526  return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
527  }
528 
529  Type fTy = operandTy;
530  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
531  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
532  iTy = shapedTy.clone(iTy);
533  }
534 
535  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
536  // The width returned by getFPMantissaWidth includes the integer bit.
537  unsigned mantissaWidth =
538  llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
539  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
540 
541  // The names of the variables correspond to f32.
542  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
543  // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
544  // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
545  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
546  Value c0 = createIntConst(loc, iTy, 0, b);
547  Value c1 = createIntConst(loc, iTy, 1, b);
548  Value cNeg1 = createIntConst(loc, iTy, -1, b);
549  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
550  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
551  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
552  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
553  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
554  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
555 
556  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
557  Value round = math::RoundOp::create(b, operand);
558  Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
559 
560  // Get biased exponents for operand and round(operand)
561  Value operandExp = arith::AndIOp::create(
562  b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
563  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
564  Value roundExp = arith::AndIOp::create(
565  b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
566  Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
567 
568  auto safeShiftRight = [&](Value x, Value shift) -> Value {
569  // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
570  Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
571  clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
572  return arith::ShRUIOp::create(b, x, clampedShift);
573  };
574 
575  auto maskMantissa = [&](Value mantissa,
576  Value mantissaMaskRightShift) -> Value {
577  Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
578  return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
579  };
580 
581  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
582  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
583  // with `biasedExp > 23` (numbers where there is not enough precision to store
584  // decimals) are always even, and they satisfy the even condition trivially
585  // since the mantissa without all its bits is zero. The even condition
586  // is also true for +-0, since they have `biasedExp = -127` and the entire
587  // mantissa is zero. The case of +-1 has to be handled separately. Here
588  // we identify these values by noting that +-1 are the only whole numbers with
589  // `biasedExp == 0`.
590  //
591  // The special values +-inf and +-nan also satisfy the same property that
592  // whole non-unit even numbers satisfy. In particular, the special values have
593  // `biasedExp > 23`, so they get treated as large numbers with no room for
594  // decimals, which are always even.
595  Value roundBiasedExpEq0 =
596  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
597  Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
598  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
599  Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
600  b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
601  roundIsNotEvenOrSpecialVal =
602  arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
603 
604  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
605  // integers if the bit at index `biasedExp` starting from the left in the
606  // mantissa is 1 and all the bits to the right are zero. Values with
607  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
608  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
609  // so these are handled separately. In particular, if `biasedExp == -1`, the
610  // value is halfway if the entire mantissa is zero.
611  Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
612  b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
613  Value expectedOperandMaskedMantissa = arith::SelectOp::create(
614  b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
615  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
616  Value operandIsHalfway =
617  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
618  expectedOperandMaskedMantissa);
619  // Ensure `biasedExp` is in the valid range for half values.
620  Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
621  b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
622  Value operandBiasedExpLt23 = arith::CmpIOp::create(
623  b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
624  operandIsHalfway =
625  arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
626  operandIsHalfway =
627  arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
628 
629  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
630  // case where `round` rounded in the opposite direction of `roundeven`.
631  Value sign = math::CopySignOp::create(b, c1Float, operand);
632  Value roundShifted = arith::SubFOp::create(b, round, sign);
633  // If the rounded value is even or a special value, we default to the behavior
634  // of `math.round`.
635  Value needsShift =
636  arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
637  Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
638  // The `x - sign` adjustment does not preserve the sign when we are adjusting
639  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
640  // rounded to -0.0.
641  result = math::CopySignOp::create(b, result, operand);
642  rewriter.replaceOp(op, result);
643  return success();
644 }
645 
646 // Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
647 static LogicalResult convertRsqrtOp(math::RsqrtOp op,
648  PatternRewriter &rewriter) {
649 
650  auto operand = op.getOperand();
651  auto operandTy = operand.getType();
652  // Operand type must be shatic shaped type to create const float.
653  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
654  if (shapedOperandType && !shapedOperandType.hasStaticShape())
655  return failure();
656 
657  auto eTy = getElementTypeOrSelf(operandTy);
658  if (!isa<FloatType>(eTy))
659  return failure();
660 
661  Location loc = op->getLoc();
662  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
663  auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
664  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
665  return success();
666 }
667 
668 // Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
669 static LogicalResult convertClampfOp(math::ClampFOp op,
670  PatternRewriter &rewriter) {
671  auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
672  op.getMin(), op.getFastmath());
673  rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
674  op.getFastmath());
675  return success();
676 }
677 
679  ArrayRef<StringRef> opMnemonics) {
680  auto filter = [&](StringRef name) {
681  // This should be a static assert and `consume_front` take a twine, but none
682  // is currently possible. TODO: augment `StringRef::consume_front` and make
683  // `getDialectNamespace` use `std::string_view`.
684  assert("math" == MathDialect::getDialectNamespace());
685  name.consume_front("math.");
686  return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
687  };
688  if (filter(CountLeadingZerosOp::getOperationName()))
689  patterns.add(convertCtlzOp);
690  if (filter(SinhOp::getOperationName()))
691  patterns.add(convertSinhOp);
692  if (filter(CoshOp::getOperationName()))
693  patterns.add(convertCoshOp);
694  if (filter(TanOp::getOperationName()))
695  patterns.add(convertTanOp);
696  if (filter(TanhOp::getOperationName()))
697  patterns.add(convertTanhOp);
698  if (filter(AsinhOp::getOperationName()))
700  if (filter(AcoshOp::getOperationName()))
702  if (filter(AtanhOp::getOperationName()))
704  if (filter(FmaOp::getOperationName()))
705  patterns.add(convertFmaFOp);
706  if (filter(CeilOp::getOperationName()))
707  patterns.add(convertCeilOp);
708  if (filter(Exp2Op::getOperationName()))
710  if (filter(PowFOp::getOperationName()))
711  patterns.add(convertPowfOp);
712  if (filter(FPowIOp::getOperationName()))
714  if (filter(RoundOp::getOperationName()))
716  if (filter(RoundEvenOp::getOperationName()))
718  if (filter(RsqrtOp::getOperationName()))
720  if (filter(ClampFOp::getOperationName()))
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // MathExpandOpsPass pass
726 //===----------------------------------------------------------------------===//
727 namespace {
728 struct MathExpandOpsPass final
729  : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
730  using MathExpandOpsPassBase::MathExpandOpsPassBase;
731 
732  void runOnOperation() override {
734  SmallVector<StringRef> mnemonics =
735  llvm::to_vector_of<StringRef>(opMnemonics);
737  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
738  return signalPassFailure();
739  }
740 };
741 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:647
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
Definition: ExpandOps.cpp:62
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:255
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:326
static LogicalResult convertClampfOp(math::ClampFOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:669
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:412
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:147
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:472
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:209
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:191
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:226
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:515
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:91
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:75
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
Definition: ExpandOps.cpp:29
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:159
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
Definition: ExpandOps.cpp:51
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into (1 - exp(-2x)) / (1 + exp(-2x)). To avoid overflow we exploit the reflection symmetry tanh(-x) = -tanh(x).
Definition: ExpandOps.cpp:114
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:175
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Definition: ExpandOps.cpp:400
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:253
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:623
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:656
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
Definition: ExpandOps.cpp:678
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition: Matchers.h:520