MLIR  22.0.0git
ExpandPatterns.cpp
Go to the documentation of this file.
1 //===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of various math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/TypeUtilities.h"
21 
22 using namespace mlir;
23 
24 /// Create a float constant.
25 static Value createFloatConst(Location loc, Type type, APFloat value,
26  OpBuilder &b) {
27  bool losesInfo = false;
28  auto eltType = getElementTypeOrSelf(type);
29  // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
30  value.convert(cast<FloatType>(eltType).getFloatSemantics(),
31  APFloat::rmNearestTiesToEven, &losesInfo);
32  auto attr = b.getFloatAttr(eltType, value);
33  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
34  return arith::ConstantOp::create(b, loc,
35  DenseElementsAttr::get(shapedTy, attr));
36  }
37 
38  return arith::ConstantOp::create(b, loc, attr);
39 }
40 
41 static Value createFloatConst(Location loc, Type type, double value,
42  OpBuilder &b) {
43  return createFloatConst(loc, type, APFloat(value), b);
44 }
45 
46 /// Create an integer constant.
47 static Value createIntConst(Location loc, Type type, int64_t value,
48  OpBuilder &b) {
49  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
50  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
51  return arith::ConstantOp::create(b, loc,
52  DenseElementsAttr::get(shapedTy, attr));
53  }
54 
55  return arith::ConstantOp::create(b, loc, attr);
56 }
57 
59  Type opType = operand.getType();
60  Type i64Ty = b.getI64Type();
61  if (auto shapedTy = dyn_cast<ShapedType>(opType))
62  i64Ty = shapedTy.clone(i64Ty);
63  Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
64  Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
65  // The truncation does not preserve the sign when the truncated
66  // value is -0. So here the sign is copied again.
67  return math::CopySignOp::create(b, fpFixedConvert, operand);
68 }
69 
70 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
71 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
72  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
73  Value operand = op.getOperand();
74  Type opType = operand.getType();
75 
76  Value exp = math::ExpOp::create(b, operand);
77  Value neg = arith::NegFOp::create(b, operand);
78  Value nexp = math::ExpOp::create(b, neg);
79  Value sub = arith::SubFOp::create(b, exp, nexp);
80  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
81  Value res = arith::MulFOp::create(b, sub, half);
82  rewriter.replaceOp(op, res);
83  return success();
84 }
85 
86 // coshf(float x) -> (exp(x) + exp(-x)) / 2
87 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
88  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
89  Value operand = op.getOperand();
90  Type opType = operand.getType();
91 
92  Value exp = math::ExpOp::create(b, operand);
93  Value neg = arith::NegFOp::create(b, operand);
94  Value nexp = math::ExpOp::create(b, neg);
95  Value add = arith::AddFOp::create(b, exp, nexp);
96  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
97  Value res = arith::MulFOp::create(b, add, half);
98  rewriter.replaceOp(op, res);
99  return success();
100 }
101 
102 /// Expands tanh op into
103 /// 1-exp^{-2x} / 1+exp^{-2x}
104 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
105 /// We compute a "signs" value which is -1 if input is negative and +1 if input
106 /// is positive. Then multiply the input by this value, guaranteeing that the
107 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
108 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
109 /// result by `sign(x)` to retain sign of the real result.
110 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
111  auto floatType = op.getOperand().getType();
112  Location loc = op.getLoc();
113  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
114  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
115  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
116 
117  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
118  Value isNegative = arith::CmpFOp::create(
119  rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
120  Value isNegativeFloat =
121  arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
122  Value isNegativeTimesNegTwo =
123  arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
124  Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
125 
126  // Normalize input to positive value: y = sign(x) * x
127  Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
128 
129  // Decompose on normalized input
130  Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
131  Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
132  Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
133  Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
134  Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
135 
136  // Multiply result by sign(x) to retain signs from negative inputs
137  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
138 
139  return success();
140 }
141 
142 // Converts math.tan to math.sin, math.cos, and arith.divf.
143 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
144  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
145  Value operand = op.getOperand();
146  Type type = operand.getType();
147  Value sin = math::SinOp::create(b, type, operand);
148  Value cos = math::CosOp::create(b, type, operand);
149  Value div = arith::DivFOp::create(b, type, sin, cos);
150  rewriter.replaceOp(op, div);
151  return success();
152 }
153 
154 // asinh(float x) -> log(x + sqrt(x**2 + 1))
155 static LogicalResult convertAsinhOp(math::AsinhOp op,
156  PatternRewriter &rewriter) {
157  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
158  Value operand = op.getOperand();
159  Type opType = operand.getType();
160 
161  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
162  Value fma = math::FmaOp::create(b, operand, operand, one);
163  Value sqrt = math::SqrtOp::create(b, fma);
164  Value add = arith::AddFOp::create(b, operand, sqrt);
165  Value res = math::LogOp::create(b, add);
166  rewriter.replaceOp(op, res);
167  return success();
168 }
169 
170 // acosh(float x) -> log(x + sqrt(x**2 - 1))
171 static LogicalResult convertAcoshOp(math::AcoshOp op,
172  PatternRewriter &rewriter) {
173  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
174  Value operand = op.getOperand();
175  Type opType = operand.getType();
176 
177  Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
178  Value fma = math::FmaOp::create(b, operand, operand, negOne);
179  Value sqrt = math::SqrtOp::create(b, fma);
180  Value add = arith::AddFOp::create(b, operand, sqrt);
181  Value res = math::LogOp::create(b, add);
182  rewriter.replaceOp(op, res);
183  return success();
184 }
185 
186 // atanh(float x) -> log((1 + x) / (1 - x)) / 2
187 static LogicalResult convertAtanhOp(math::AtanhOp op,
188  PatternRewriter &rewriter) {
189  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
190  Value operand = op.getOperand();
191  Type opType = operand.getType();
192 
193  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
194  Value add = arith::AddFOp::create(b, operand, one);
195  Value neg = arith::NegFOp::create(b, operand);
196  Value sub = arith::AddFOp::create(b, neg, one);
197  Value div = arith::DivFOp::create(b, add, sub);
198  Value log = math::LogOp::create(b, div);
199  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
200  Value res = arith::MulFOp::create(b, log, half);
201  rewriter.replaceOp(op, res);
202  return success();
203 }
204 
205 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
206  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
207  Value operandA = op.getOperand(0);
208  Value operandB = op.getOperand(1);
209  Value operandC = op.getOperand(2);
210  Type type = op.getType();
211  Value mult = arith::MulFOp::create(b, type, operandA, operandB);
212  Value add = arith::AddFOp::create(b, type, mult, operandC);
213  rewriter.replaceOp(op, add);
214  return success();
215 }
216 
217 // Converts a ceilf() function to the following:
218 // ceilf(float x) ->
219 // y = (float)(int) x
220 // if (x > y) then incr = 1 else incr = 0
221 // y = y + incr <= replace this op with the ceilf op.
222 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
223  // Creating constants assumes the static shaped type.
224  auto shapedType = dyn_cast<ShapedType>(op.getType());
225  if (shapedType && !shapedType.hasStaticShape())
226  return failure();
227 
228  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
229  Value operand = op.getOperand();
230  Type opType = operand.getType();
231  Value fpFixedConvert = createTruncatedFPValue(operand, b);
232 
233  // Creating constants for later use.
234  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
235  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
236 
237  Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
238  fpFixedConvert);
239  Value incrValue =
240  arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
241 
242  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
243  rewriter.replaceOp(op, ret);
244  return success();
245 }
246 
247 // Convert `math.fpowi` to a series of `arith.mulf` operations.
248 // If the power is negative, we divide one by the result.
249 // If both the base and power are zero, the result is 1.
250 // In the case of non constant power, we convert the operation to `math.powf`.
251 static LogicalResult convertFPowIOp(math::FPowIOp op,
252  PatternRewriter &rewriter) {
253  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
254  Value base = op.getOperand(0);
255  Value power = op.getOperand(1);
256  Type baseType = base.getType();
257 
258  auto convertFPowItoPowf = [&]() -> LogicalResult {
259  Value castPowerToFp =
260  arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
261  Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
262  castPowerToFp);
263  rewriter.replaceOp(op, res);
264  return success();
265  };
266 
267  Attribute cstAttr;
268  if (!matchPattern(power, m_Constant(&cstAttr)))
269  return convertFPowItoPowf();
270 
271  APInt value;
272  if (!matchPattern(cstAttr, m_ConstantInt(&value)))
273  return convertFPowItoPowf();
274 
275  int64_t powerInt = value.getSExtValue();
276  bool isNegative = powerInt < 0;
277  int64_t absPower = std::abs(powerInt);
278  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
279  Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
280 
281  while (absPower > 0) {
282  if (absPower & 1)
283  res = arith::MulFOp::create(b, baseType, base, res);
284  absPower >>= 1;
285  base = arith::MulFOp::create(b, baseType, base, base);
286  }
287 
288  // Make sure not to introduce UB in case of negative power.
289  if (isNegative) {
290  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
291  .getFloatSemantics();
292  Value zero =
293  createFloatConst(op->getLoc(), baseType,
294  APFloat::getZero(sem, /*Negative=*/false), rewriter);
295  Value negZero =
296  createFloatConst(op->getLoc(), baseType,
297  APFloat::getZero(sem, /*Negative=*/true), rewriter);
298  Value posInfinity =
299  createFloatConst(op->getLoc(), baseType,
300  APFloat::getInf(sem, /*Negative=*/false), rewriter);
301  Value negInfinity =
302  createFloatConst(op->getLoc(), baseType,
303  APFloat::getInf(sem, /*Negative=*/true), rewriter);
304  Value zeroEqCheck =
305  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
306  Value negZeroEqCheck =
307  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
308  res = arith::DivFOp::create(b, baseType, one, res);
309  res =
310  arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
311  res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
312  res);
313  }
314 
315  rewriter.replaceOp(op, res);
316  return success();
317 }
318 
319 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
320 // Some special cases where b is constant are handled separately:
321 // when b == 0, or |b| == 0.5, 1.0, or 2.0.
322 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
323  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
324  Value operandA = op.getOperand(0);
325  Value operandB = op.getOperand(1);
326  auto typeA = operandA.getType();
327  auto typeB = operandB.getType();
328 
329  auto &sem =
330  cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
331  APFloat valueB(sem);
332  auto mulf = [&](Value x, Value y) -> Value {
333  return arith::MulFOp::create(b, x, y);
334  };
335  if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
336  if (valueB.isZero()) {
337  // a^0 -> 1
338  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
339  rewriter.replaceOp(op, one);
340  return success();
341  }
342  if (valueB.isExactlyValue(1.0)) {
343  // a^1 -> a
344  rewriter.replaceOp(op, operandA);
345  return success();
346  }
347  if (valueB.isExactlyValue(-1.0)) {
348  // a^(-1) -> 1 / a
349  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
350  Value div = arith::DivFOp::create(b, one, operandA);
351  rewriter.replaceOp(op, div);
352  return success();
353  }
354  if (valueB.isExactlyValue(0.5)) {
355  // a^(1/2) -> sqrt(a)
356  Value sqrt = math::SqrtOp::create(b, operandA);
357  rewriter.replaceOp(op, sqrt);
358  return success();
359  }
360  if (valueB.isExactlyValue(-0.5)) {
361  // a^(-1/2) -> 1 / sqrt(a)
362  Value rsqrt = math::RsqrtOp::create(b, operandA);
363  rewriter.replaceOp(op, rsqrt);
364  return success();
365  }
366  if (valueB.isExactlyValue(2.0)) {
367  // a^2 -> a * a
368  rewriter.replaceOp(op, mulf(operandA, operandA));
369  return success();
370  }
371  if (valueB.isExactlyValue(-2.0)) {
372  // a^(-2) -> 1 / (a * a)
373  Value one =
374  createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
375  Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
376  rewriter.replaceOp(op, div);
377  return success();
378  }
379  if (valueB.isExactlyValue(3.0)) {
380  rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
381  return success();
382  }
383  }
384 
385  Value logA = math::LogOp::create(b, operandA);
386  Value mult = arith::MulFOp::create(b, operandB, logA);
387  Value expResult = math::ExpOp::create(b, mult);
388  rewriter.replaceOp(op, expResult);
389  return success();
390 }
391 
392 // exp2f(float x) -> exp(x * ln(2))
393 // Proof: Let's say 2^x = y
394 // ln(2^x) = ln(y)
395 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y
396 static LogicalResult convertExp2fOp(math::Exp2Op op,
397  PatternRewriter &rewriter) {
398  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
399  Value operand = op.getOperand();
400  Type opType = operand.getType();
401  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
402  Value mult = arith::MulFOp::create(b, opType, operand, ln2);
403  Value exp = math::ExpOp::create(b, op->getLoc(), mult);
404  rewriter.replaceOp(op, exp);
405  return success();
406 }
407 
408 static LogicalResult convertRoundOp(math::RoundOp op,
409  PatternRewriter &rewriter) {
410  Location loc = op.getLoc();
411  ImplicitLocOpBuilder b(loc, rewriter);
412  Value operand = op.getOperand();
413  Type opType = operand.getType();
414  Type opEType = getElementTypeOrSelf(opType);
415 
416  if (!opEType.isF32()) {
417  return rewriter.notifyMatchFailure(op, "not a round of f32.");
418  }
419 
420  Type i32Ty = b.getI32Type();
421  if (auto shapedTy = dyn_cast<ShapedType>(opType))
422  i32Ty = shapedTy.clone(i32Ty);
423 
424  Value half = createFloatConst(loc, opType, 0.5, b);
425  Value c23 = createIntConst(loc, i32Ty, 23, b);
426  Value c127 = createIntConst(loc, i32Ty, 127, b);
427  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
428 
429  Value incrValue = math::CopySignOp::create(b, half, operand);
430  Value add = arith::AddFOp::create(b, opType, operand, incrValue);
431  Value fpFixedConvert = createTruncatedFPValue(add, b);
432 
433  // There are three cases where adding 0.5 to the value and truncating by
434  // converting to an i64 does not result in the correct behavior:
435  //
436  // 1. Special values: +-inf and +-nan
437  // Casting these special values to i64 has undefined behavior. To identify
438  // these values, we use the fact that these values are the only float
439  // values with the maximum possible biased exponent.
440  //
441  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
442  // Adding 0.5 to a float larger than or equal to 2^23 results in precision
443  // errors that sometimes round the value up and sometimes round the value
444  // down. For example:
445  // 8388608.0 + 0.5 = 8388608.0
446  // 8388609.0 + 0.5 = 8388610.0
447  //
448  // 3. Very large values: |x| > INT_64_MAX
449  // Casting to i64 a value greater than the max i64 value will overflow the
450  // i64 leading to wrong outputs.
451  //
452  // All three cases satisfy the property `biasedExp >= 23`.
453  Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
454  Value operandExp = arith::AndIOp::create(
455  b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
456  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
457  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
458  b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
459 
460  Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
461  fpFixedConvert);
462  rewriter.replaceOp(op, result);
463  return success();
464 }
465 
466 // Converts math.ctlz to scf and arith operations. This is done
467 // by performing a binary search on the bits.
468 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
469  PatternRewriter &rewriter) {
470  auto operand = op.getOperand();
471  auto operandTy = operand.getType();
472  auto eTy = getElementTypeOrSelf(operandTy);
473  Location loc = op.getLoc();
474 
475  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
476  if (bitwidth > 64)
477  return failure();
478 
479  uint64_t allbits = -1;
480  if (bitwidth < 64) {
481  allbits = allbits >> (64 - bitwidth);
482  }
483 
484  Value x = operand;
485  Value count = createIntConst(loc, operandTy, 0, rewriter);
486  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
487  auto half = bw / 2;
488  auto bits = createIntConst(loc, operandTy, half, rewriter);
489  auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
490 
491  Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
492  x, mask);
493  Value add = arith::AddIOp::create(rewriter, loc, count, bits);
494  Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
495 
496  x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
497  count = arith::SelectOp::create(rewriter, loc, pred, add, count);
498  }
499 
500  Value zero = createIntConst(loc, operandTy, 0, rewriter);
501  Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
502  operand, zero);
503 
504  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
505  Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
506  rewriter.replaceOp(op, sel);
507  return success();
508 }
509 
510 // Convert `math.roundeven` into `math.round` + arith ops
511 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
512  PatternRewriter &rewriter) {
513  Location loc = op.getLoc();
514  ImplicitLocOpBuilder b(loc, rewriter);
515  auto operand = op.getOperand();
516  Type operandTy = operand.getType();
517  Type resultTy = op.getType();
518  Type operandETy = getElementTypeOrSelf(operandTy);
519  Type resultETy = getElementTypeOrSelf(resultTy);
520 
521  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
522  return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
523  }
524 
525  Type fTy = operandTy;
526  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
527  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
528  iTy = shapedTy.clone(iTy);
529  }
530 
531  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
532  // The width returned by getFPMantissaWidth includes the integer bit.
533  unsigned mantissaWidth =
534  llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
535  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
536 
537  // The names of the variables correspond to f32.
538  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
539  // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
540  // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
541  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
542  Value c0 = createIntConst(loc, iTy, 0, b);
543  Value c1 = createIntConst(loc, iTy, 1, b);
544  Value cNeg1 = createIntConst(loc, iTy, -1, b);
545  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
546  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
547  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
548  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
549  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
550  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
551 
552  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
553  Value round = math::RoundOp::create(b, operand);
554  Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
555 
556  // Get biased exponents for operand and round(operand)
557  Value operandExp = arith::AndIOp::create(
558  b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
559  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
560  Value roundExp = arith::AndIOp::create(
561  b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
562  Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
563 
564  auto safeShiftRight = [&](Value x, Value shift) -> Value {
565  // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
566  Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
567  clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
568  return arith::ShRUIOp::create(b, x, clampedShift);
569  };
570 
571  auto maskMantissa = [&](Value mantissa,
572  Value mantissaMaskRightShift) -> Value {
573  Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
574  return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
575  };
576 
577  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
578  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
579  // with `biasedExp > 23` (numbers where there is not enough precision to store
580  // decimals) are always even, and they satisfy the even condition trivially
581  // since the mantissa without all its bits is zero. The even condition
582  // is also true for +-0, since they have `biasedExp = -127` and the entire
583  // mantissa is zero. The case of +-1 has to be handled separately. Here
584  // we identify these values by noting that +-1 are the only whole numbers with
585  // `biasedExp == 0`.
586  //
587  // The special values +-inf and +-nan also satisfy the same property that
588  // whole non-unit even numbers satisfy. In particular, the special values have
589  // `biasedExp > 23`, so they get treated as large numbers with no room for
590  // decimals, which are always even.
591  Value roundBiasedExpEq0 =
592  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
593  Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
594  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
595  Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
596  b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
597  roundIsNotEvenOrSpecialVal =
598  arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
599 
600  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
601  // integers if the bit at index `biasedExp` starting from the left in the
602  // mantissa is 1 and all the bits to the right are zero. Values with
603  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
604  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
605  // so these are handled separately. In particular, if `biasedExp == -1`, the
606  // value is halfway if the entire mantissa is zero.
607  Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
608  b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
609  Value expectedOperandMaskedMantissa = arith::SelectOp::create(
610  b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
611  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
612  Value operandIsHalfway =
613  arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
614  expectedOperandMaskedMantissa);
615  // Ensure `biasedExp` is in the valid range for half values.
616  Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
617  b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
618  Value operandBiasedExpLt23 = arith::CmpIOp::create(
619  b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
620  operandIsHalfway =
621  arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
622  operandIsHalfway =
623  arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
624 
625  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
626  // case where `round` rounded in the opposite direction of `roundeven`.
627  Value sign = math::CopySignOp::create(b, c1Float, operand);
628  Value roundShifted = arith::SubFOp::create(b, round, sign);
629  // If the rounded value is even or a special value, we default to the behavior
630  // of `math.round`.
631  Value needsShift =
632  arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
633  Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
634  // The `x - sign` adjustment does not preserve the sign when we are adjusting
635  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
636  // rounded to -0.0.
637  result = math::CopySignOp::create(b, result, operand);
638  rewriter.replaceOp(op, result);
639  return success();
640 }
641 
642 // Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
643 static LogicalResult convertRsqrtOp(math::RsqrtOp op,
644  PatternRewriter &rewriter) {
645 
646  auto operand = op.getOperand();
647  auto operandTy = operand.getType();
648  // Operand type must be shatic shaped type to create const float.
649  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
650  if (shapedOperandType && !shapedOperandType.hasStaticShape())
651  return failure();
652 
653  auto eTy = getElementTypeOrSelf(operandTy);
654  if (!isa<FloatType>(eTy))
655  return failure();
656 
657  Location loc = op->getLoc();
658  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
659  auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
660  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
661  return success();
662 }
663 
665  patterns.add(convertCtlzOp);
666 }
667 
669  patterns.add(convertSinhOp);
670 }
671 
673  patterns.add(convertCoshOp);
674 }
675 
677  patterns.add(convertTanOp);
678 }
679 
681  patterns.add(convertTanhOp);
682 }
683 
686 }
687 
690 }
691 
694 }
695 
697  patterns.add(convertFmaFOp);
698 }
699 
701  patterns.add(convertCeilOp);
702 }
703 
706 }
707 
709  patterns.add(convertPowfOp);
710 }
711 
714 }
715 
718 }
719 
722 }
723 
726 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:654
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::LazyTextBuild add(const char *fmt, Ts &&...ts)
Create a Remark with llvm::formatv formatting.
Definition: Remarks.h:463
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition: Matchers.h:520