MLIR  21.0.0git
ExpandPatterns.cpp
Go to the documentation of this file.
1 //===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of various math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
22 #include "llvm/ADT/APFloat.h"
23 
24 using namespace mlir;
25 
26 /// Create a float constant.
27 static Value createFloatConst(Location loc, Type type, APFloat value,
28  OpBuilder &b) {
29  bool losesInfo = false;
30  auto eltType = getElementTypeOrSelf(type);
31  // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
32  value.convert(cast<FloatType>(eltType).getFloatSemantics(),
33  APFloat::rmNearestTiesToEven, &losesInfo);
34  auto attr = b.getFloatAttr(eltType, value);
35  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
36  return b.create<arith::ConstantOp>(loc,
37  DenseElementsAttr::get(shapedTy, attr));
38  }
39 
40  return b.create<arith::ConstantOp>(loc, attr);
41 }
42 
43 static Value createFloatConst(Location loc, Type type, double value,
44  OpBuilder &b) {
45  return createFloatConst(loc, type, APFloat(value), b);
46 }
47 
48 /// Create an integer constant.
49 static Value createIntConst(Location loc, Type type, int64_t value,
50  OpBuilder &b) {
51  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
52  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
53  return b.create<arith::ConstantOp>(loc,
54  DenseElementsAttr::get(shapedTy, attr));
55  }
56 
57  return b.create<arith::ConstantOp>(loc, attr);
58 }
59 
61  Type opType = operand.getType();
62  Type i64Ty = b.getI64Type();
63  if (auto shapedTy = dyn_cast<ShapedType>(opType))
64  i64Ty = shapedTy.clone(i64Ty);
65  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
66  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
67  // The truncation does not preserve the sign when the truncated
68  // value is -0. So here the sign is copied again.
69  return b.create<math::CopySignOp>(fpFixedConvert, operand);
70 }
71 
72 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
73 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
74  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
75  Value operand = op.getOperand();
76  Type opType = operand.getType();
77 
78  Value exp = b.create<math::ExpOp>(operand);
79  Value neg = b.create<arith::NegFOp>(operand);
80  Value nexp = b.create<math::ExpOp>(neg);
81  Value sub = b.create<arith::SubFOp>(exp, nexp);
82  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
83  Value res = b.create<arith::MulFOp>(sub, half);
84  rewriter.replaceOp(op, res);
85  return success();
86 }
87 
88 // coshf(float x) -> (exp(x) + exp(-x)) / 2
89 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
90  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
91  Value operand = op.getOperand();
92  Type opType = operand.getType();
93 
94  Value exp = b.create<math::ExpOp>(operand);
95  Value neg = b.create<arith::NegFOp>(operand);
96  Value nexp = b.create<math::ExpOp>(neg);
97  Value add = b.create<arith::AddFOp>(exp, nexp);
98  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
99  Value res = b.create<arith::MulFOp>(add, half);
100  rewriter.replaceOp(op, res);
101  return success();
102 }
103 
104 /// Expands tanh op into
105 /// 1-exp^{-2x} / 1+exp^{-2x}
106 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
107 /// We compute a "signs" value which is -1 if input is negative and +1 if input
108 /// is positive. Then multiply the input by this value, guaranteeing that the
109 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
110 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
111 /// result by `sign(x)` to retain sign of the real result.
112 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
113  auto floatType = op.getOperand().getType();
114  Location loc = op.getLoc();
115  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
116  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
117  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
118 
119  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
120  Value isNegative = rewriter.create<arith::CmpFOp>(
121  loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
122  Value isNegativeFloat =
123  rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
124  Value isNegativeTimesNegTwo =
125  rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
126  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
127 
128  // Normalize input to positive value: y = sign(x) * x
129  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
130 
131  // Decompose on normalized input
132  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
133  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
134  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
135  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
136  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
137 
138  // Multiply result by sign(x) to retain signs from negative inputs
139  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
140 
141  return success();
142 }
143 
144 // Converts math.tan to math.sin, math.cos, and arith.divf.
145 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
146  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
147  Value operand = op.getOperand();
148  Type type = operand.getType();
149  Value sin = b.create<math::SinOp>(type, operand);
150  Value cos = b.create<math::CosOp>(type, operand);
151  Value div = b.create<arith::DivFOp>(type, sin, cos);
152  rewriter.replaceOp(op, div);
153  return success();
154 }
155 
156 // asinh(float x) -> log(x + sqrt(x**2 + 1))
157 static LogicalResult convertAsinhOp(math::AsinhOp op,
158  PatternRewriter &rewriter) {
159  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
160  Value operand = op.getOperand();
161  Type opType = operand.getType();
162 
163  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
164  Value fma = b.create<math::FmaOp>(operand, operand, one);
165  Value sqrt = b.create<math::SqrtOp>(fma);
166  Value add = b.create<arith::AddFOp>(operand, sqrt);
167  Value res = b.create<math::LogOp>(add);
168  rewriter.replaceOp(op, res);
169  return success();
170 }
171 
172 // acosh(float x) -> log(x + sqrt(x**2 - 1))
173 static LogicalResult convertAcoshOp(math::AcoshOp op,
174  PatternRewriter &rewriter) {
175  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
176  Value operand = op.getOperand();
177  Type opType = operand.getType();
178 
179  Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
180  Value fma = b.create<math::FmaOp>(operand, operand, negOne);
181  Value sqrt = b.create<math::SqrtOp>(fma);
182  Value add = b.create<arith::AddFOp>(operand, sqrt);
183  Value res = b.create<math::LogOp>(add);
184  rewriter.replaceOp(op, res);
185  return success();
186 }
187 
188 // atanh(float x) -> log((1 + x) / (1 - x)) / 2
189 static LogicalResult convertAtanhOp(math::AtanhOp op,
190  PatternRewriter &rewriter) {
191  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
192  Value operand = op.getOperand();
193  Type opType = operand.getType();
194 
195  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
196  Value add = b.create<arith::AddFOp>(operand, one);
197  Value neg = b.create<arith::NegFOp>(operand);
198  Value sub = b.create<arith::AddFOp>(neg, one);
199  Value div = b.create<arith::DivFOp>(add, sub);
200  Value log = b.create<math::LogOp>(div);
201  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
202  Value res = b.create<arith::MulFOp>(log, half);
203  rewriter.replaceOp(op, res);
204  return success();
205 }
206 
207 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
208  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
209  Value operandA = op.getOperand(0);
210  Value operandB = op.getOperand(1);
211  Value operandC = op.getOperand(2);
212  Type type = op.getType();
213  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
214  Value add = b.create<arith::AddFOp>(type, mult, operandC);
215  rewriter.replaceOp(op, add);
216  return success();
217 }
218 
219 // Converts a ceilf() function to the following:
220 // ceilf(float x) ->
221 // y = (float)(int) x
222 // if (x > y) then incr = 1 else incr = 0
223 // y = y + incr <= replace this op with the ceilf op.
224 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
225  // Creating constants assumes the static shaped type.
226  auto shapedType = dyn_cast<ShapedType>(op.getType());
227  if (shapedType && !shapedType.hasStaticShape())
228  return failure();
229 
230  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
231  Value operand = op.getOperand();
232  Type opType = operand.getType();
233  Value fpFixedConvert = createTruncatedFPValue(operand, b);
234 
235  // Creating constants for later use.
236  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
237  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
238 
239  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
240  fpFixedConvert);
241  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
242 
243  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
244  rewriter.replaceOp(op, ret);
245  return success();
246 }
247 
248 // Convert `math.fpowi` to a series of `arith.mulf` operations.
249 // If the power is negative, we divide one by the result.
250 // If both the base and power are zero, the result is 1.
251 // In the case of non constant power, we convert the operation to `math.powf`.
252 static LogicalResult convertFPowIOp(math::FPowIOp op,
253  PatternRewriter &rewriter) {
254  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
255  Value base = op.getOperand(0);
256  Value power = op.getOperand(1);
257  Type baseType = base.getType();
258 
259  auto convertFPowItoPowf = [&]() -> LogicalResult {
260  Value castPowerToFp =
261  rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
262  Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
263  castPowerToFp);
264  rewriter.replaceOp(op, res);
265  return success();
266  };
267 
268  Attribute cstAttr;
269  if (!matchPattern(power, m_Constant(&cstAttr)))
270  return convertFPowItoPowf();
271 
272  APInt value;
273  if (!matchPattern(cstAttr, m_ConstantInt(&value)))
274  return convertFPowItoPowf();
275 
276  int64_t powerInt = value.getSExtValue();
277  bool isNegative = powerInt < 0;
278  int64_t absPower = std::abs(powerInt);
279  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
280  Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
281 
282  while (absPower > 0) {
283  if (absPower & 1)
284  res = b.create<arith::MulFOp>(baseType, base, res);
285  absPower >>= 1;
286  base = b.create<arith::MulFOp>(baseType, base, base);
287  }
288 
289  // Make sure not to introduce UB in case of negative power.
290  if (isNegative) {
291  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
292  .getFloatSemantics();
293  Value zero =
294  createFloatConst(op->getLoc(), baseType,
295  APFloat::getZero(sem, /*Negative=*/false), rewriter);
296  Value negZero =
297  createFloatConst(op->getLoc(), baseType,
298  APFloat::getZero(sem, /*Negative=*/true), rewriter);
299  Value posInfinity =
300  createFloatConst(op->getLoc(), baseType,
301  APFloat::getInf(sem, /*Negative=*/false), rewriter);
302  Value negInfinity =
303  createFloatConst(op->getLoc(), baseType,
304  APFloat::getInf(sem, /*Negative=*/true), rewriter);
305  Value zeroEqCheck =
306  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
307  Value negZeroEqCheck =
308  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
309  res = b.create<arith::DivFOp>(baseType, one, res);
310  res =
311  b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
312  res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
313  res);
314  }
315 
316  rewriter.replaceOp(op, res);
317  return success();
318 }
319 
320 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
321 // Some special cases where b is constant are handled separately:
322 // when b == 0, or |b| == 0.5, 1.0, or 2.0.
323 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
324  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
325  Value operandA = op.getOperand(0);
326  Value operandB = op.getOperand(1);
327  auto typeA = operandA.getType();
328  auto typeB = operandB.getType();
329 
330  auto &sem =
331  cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
332  APFloat valueB(sem);
333  auto mulf = [&](Value x, Value y) -> Value {
334  return b.create<arith::MulFOp>(x, y);
335  };
336  if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
337  if (valueB.isZero()) {
338  // a^0 -> 1
339  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
340  rewriter.replaceOp(op, one);
341  return success();
342  }
343  if (valueB.isExactlyValue(1.0)) {
344  // a^1 -> a
345  rewriter.replaceOp(op, operandA);
346  return success();
347  }
348  if (valueB.isExactlyValue(-1.0)) {
349  // a^(-1) -> 1 / a
350  Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
351  Value div = b.create<arith::DivFOp>(one, operandA);
352  rewriter.replaceOp(op, div);
353  return success();
354  }
355  if (valueB.isExactlyValue(0.5)) {
356  // a^(1/2) -> sqrt(a)
357  Value sqrt = b.create<math::SqrtOp>(operandA);
358  rewriter.replaceOp(op, sqrt);
359  return success();
360  }
361  if (valueB.isExactlyValue(-0.5)) {
362  // a^(-1/2) -> 1 / sqrt(a)
363  Value rsqrt = b.create<math::RsqrtOp>(operandA);
364  rewriter.replaceOp(op, rsqrt);
365  return success();
366  }
367  if (valueB.isExactlyValue(2.0)) {
368  // a^2 -> a * a
369  rewriter.replaceOp(op, mulf(operandA, operandA));
370  return success();
371  }
372  if (valueB.isExactlyValue(-2.0)) {
373  // a^(-2) -> 1 / (a * a)
374  Value one =
375  createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
376  Value div = b.create<arith::DivFOp>(one, mulf(operandA, operandA));
377  rewriter.replaceOp(op, div);
378  return success();
379  }
380  if (valueB.isExactlyValue(3.0)) {
381  rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
382  return success();
383  }
384  }
385 
386  Value logA = b.create<math::LogOp>(operandA);
387  Value mult = b.create<arith::MulFOp>(operandB, logA);
388  Value expResult = b.create<math::ExpOp>(mult);
389  rewriter.replaceOp(op, expResult);
390  return success();
391 }
392 
393 // exp2f(float x) -> exp(x * ln(2))
394 // Proof: Let's say 2^x = y
395 // ln(2^x) = ln(y)
396 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y
397 static LogicalResult convertExp2fOp(math::Exp2Op op,
398  PatternRewriter &rewriter) {
399  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
400  Value operand = op.getOperand();
401  Type opType = operand.getType();
402  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
403  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
404  Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
405  rewriter.replaceOp(op, exp);
406  return success();
407 }
408 
409 static LogicalResult convertRoundOp(math::RoundOp op,
410  PatternRewriter &rewriter) {
411  Location loc = op.getLoc();
412  ImplicitLocOpBuilder b(loc, rewriter);
413  Value operand = op.getOperand();
414  Type opType = operand.getType();
415  Type opEType = getElementTypeOrSelf(opType);
416 
417  if (!opEType.isF32()) {
418  return rewriter.notifyMatchFailure(op, "not a round of f32.");
419  }
420 
421  Type i32Ty = b.getI32Type();
422  if (auto shapedTy = dyn_cast<ShapedType>(opType))
423  i32Ty = shapedTy.clone(i32Ty);
424 
425  Value half = createFloatConst(loc, opType, 0.5, b);
426  Value c23 = createIntConst(loc, i32Ty, 23, b);
427  Value c127 = createIntConst(loc, i32Ty, 127, b);
428  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
429 
430  Value incrValue = b.create<math::CopySignOp>(half, operand);
431  Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
432  Value fpFixedConvert = createTruncatedFPValue(add, b);
433 
434  // There are three cases where adding 0.5 to the value and truncating by
435  // converting to an i64 does not result in the correct behavior:
436  //
437  // 1. Special values: +-inf and +-nan
438  // Casting these special values to i64 has undefined behavior. To identify
439  // these values, we use the fact that these values are the only float
440  // values with the maximum possible biased exponent.
441  //
442  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
443  // Adding 0.5 to a float larger than or equal to 2^23 results in precision
444  // errors that sometimes round the value up and sometimes round the value
445  // down. For example:
446  // 8388608.0 + 0.5 = 8388608.0
447  // 8388609.0 + 0.5 = 8388610.0
448  //
449  // 3. Very large values: |x| > INT_64_MAX
450  // Casting to i64 a value greater than the max i64 value will overflow the
451  // i64 leading to wrong outputs.
452  //
453  // All three cases satisfy the property `biasedExp >= 23`.
454  Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
455  Value operandExp = b.create<arith::AndIOp>(
456  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
457  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
458  Value isSpecialValOrLargeVal =
459  b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
460 
461  Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
462  fpFixedConvert);
463  rewriter.replaceOp(op, result);
464  return success();
465 }
466 
467 // Converts math.ctlz to scf and arith operations. This is done
468 // by performing a binary search on the bits.
469 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
470  PatternRewriter &rewriter) {
471  auto operand = op.getOperand();
472  auto operandTy = operand.getType();
473  auto eTy = getElementTypeOrSelf(operandTy);
474  Location loc = op.getLoc();
475 
476  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
477  if (bitwidth > 64)
478  return failure();
479 
480  uint64_t allbits = -1;
481  if (bitwidth < 64) {
482  allbits = allbits >> (64 - bitwidth);
483  }
484 
485  Value x = operand;
486  Value count = createIntConst(loc, operandTy, 0, rewriter);
487  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
488  auto half = bw / 2;
489  auto bits = createIntConst(loc, operandTy, half, rewriter);
490  auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
491 
492  Value pred =
493  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
494  Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
495  Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
496 
497  x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
498  count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
499  }
500 
501  Value zero = createIntConst(loc, operandTy, 0, rewriter);
502  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
503  operand, zero);
504 
505  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
506  Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
507  rewriter.replaceOp(op, sel);
508  return success();
509 }
510 
511 // Convert `math.roundeven` into `math.round` + arith ops
512 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
513  PatternRewriter &rewriter) {
514  Location loc = op.getLoc();
515  ImplicitLocOpBuilder b(loc, rewriter);
516  auto operand = op.getOperand();
517  Type operandTy = operand.getType();
518  Type resultTy = op.getType();
519  Type operandETy = getElementTypeOrSelf(operandTy);
520  Type resultETy = getElementTypeOrSelf(resultTy);
521 
522  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
523  return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
524  }
525 
526  Type fTy = operandTy;
527  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
528  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
529  iTy = shapedTy.clone(iTy);
530  }
531 
532  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
533  // The width returned by getFPMantissaWidth includes the integer bit.
534  unsigned mantissaWidth =
535  llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
536  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
537 
538  // The names of the variables correspond to f32.
539  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
540  // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
541  // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
542  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
543  Value c0 = createIntConst(loc, iTy, 0, b);
544  Value c1 = createIntConst(loc, iTy, 1, b);
545  Value cNeg1 = createIntConst(loc, iTy, -1, b);
546  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
547  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
548  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
549  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
550  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
551  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
552 
553  Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
554  Value round = b.create<math::RoundOp>(operand);
555  Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
556 
557  // Get biased exponents for operand and round(operand)
558  Value operandExp = b.create<arith::AndIOp>(
559  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
560  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
561  Value roundExp = b.create<arith::AndIOp>(
562  b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
563  Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
564 
565  auto safeShiftRight = [&](Value x, Value shift) -> Value {
566  // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
567  Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
568  clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
569  return b.create<arith::ShRUIOp>(x, clampedShift);
570  };
571 
572  auto maskMantissa = [&](Value mantissa,
573  Value mantissaMaskRightShift) -> Value {
574  Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
575  return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
576  };
577 
578  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
579  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
580  // with `biasedExp > 23` (numbers where there is not enough precision to store
581  // decimals) are always even, and they satisfy the even condition trivially
582  // since the mantissa without all its bits is zero. The even condition
583  // is also true for +-0, since they have `biasedExp = -127` and the entire
584  // mantissa is zero. The case of +-1 has to be handled separately. Here
585  // we identify these values by noting that +-1 are the only whole numbers with
586  // `biasedExp == 0`.
587  //
588  // The special values +-inf and +-nan also satisfy the same property that
589  // whole non-unit even numbers satisfy. In particular, the special values have
590  // `biasedExp > 23`, so they get treated as large numbers with no room for
591  // decimals, which are always even.
592  Value roundBiasedExpEq0 =
593  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
594  Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
595  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
596  Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
597  arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
598  roundIsNotEvenOrSpecialVal =
599  b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
600 
601  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
602  // integers if the bit at index `biasedExp` starting from the left in the
603  // mantissa is 1 and all the bits to the right are zero. Values with
604  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
605  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
606  // so these are handled separately. In particular, if `biasedExp == -1`, the
607  // value is halfway if the entire mantissa is zero.
608  Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
609  arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
610  Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
611  operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
612  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
613  Value operandIsHalfway =
614  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
615  expectedOperandMaskedMantissa);
616  // Ensure `biasedExp` is in the valid range for half values.
617  Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
618  arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
619  Value operandBiasedExpLt23 =
620  b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
621  operandIsHalfway =
622  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
623  operandIsHalfway =
624  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
625 
626  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
627  // case where `round` rounded in the opposite direction of `roundeven`.
628  Value sign = b.create<math::CopySignOp>(c1Float, operand);
629  Value roundShifted = b.create<arith::SubFOp>(round, sign);
630  // If the rounded value is even or a special value, we default to the behavior
631  // of `math.round`.
632  Value needsShift =
633  b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
634  Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
635  // The `x - sign` adjustment does not preserve the sign when we are adjusting
636  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
637  // rounded to -0.0.
638  result = b.create<math::CopySignOp>(result, operand);
639  rewriter.replaceOp(op, result);
640  return success();
641 }
642 
643 // Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
644 static LogicalResult convertRsqrtOp(math::RsqrtOp op,
645  PatternRewriter &rewriter) {
646 
647  auto operand = op.getOperand();
648  auto operandTy = operand.getType();
649  // Operand type must be shatic shaped type to create const float.
650  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
651  if (shapedOperandType && !shapedOperandType.hasStaticShape())
652  return failure();
653 
654  auto eTy = getElementTypeOrSelf(operandTy);
655  if (!isa<FloatType>(eTy))
656  return failure();
657 
658  Location loc = op->getLoc();
659  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
660  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
661  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
662  return success();
663 }
664 
666  patterns.add(convertCtlzOp);
667 }
668 
670  patterns.add(convertSinhOp);
671 }
672 
674  patterns.add(convertCoshOp);
675 }
676 
678  patterns.add(convertTanOp);
679 }
680 
682  patterns.add(convertTanhOp);
683 }
684 
687 }
688 
691 }
692 
695 }
696 
698  patterns.add(convertFmaFOp);
699 }
700 
702  patterns.add(convertCeilOp);
703 }
704 
707 }
708 
710  patterns.add(convertPowfOp);
711 }
712 
715 }
716 
719 }
720 
723 }
724 
727 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition: Matchers.h:520