MLIR  20.0.0git
ExpandPatterns.cpp
Go to the documentation of this file.
1 //===- ExpandPatterns.cpp - Code to expand various math operations. -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of various math operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/TypeUtilities.h"
22 
23 using namespace mlir;
24 
25 /// Create a float constant.
26 static Value createFloatConst(Location loc, Type type, APFloat value,
27  OpBuilder &b) {
28  bool losesInfo = false;
29  auto eltType = getElementTypeOrSelf(type);
30  // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
31  value.convert(cast<FloatType>(eltType).getFloatSemantics(),
32  APFloat::rmNearestTiesToEven, &losesInfo);
33  auto attr = b.getFloatAttr(eltType, value);
34  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
35  return b.create<arith::ConstantOp>(loc,
36  DenseElementsAttr::get(shapedTy, attr));
37  }
38 
39  return b.create<arith::ConstantOp>(loc, attr);
40 }
41 
42 static Value createFloatConst(Location loc, Type type, double value,
43  OpBuilder &b) {
44  return createFloatConst(loc, type, APFloat(value), b);
45 }
46 
47 /// Create an integer constant.
48 static Value createIntConst(Location loc, Type type, int64_t value,
49  OpBuilder &b) {
50  auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
51  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
52  return b.create<arith::ConstantOp>(loc,
53  DenseElementsAttr::get(shapedTy, attr));
54  }
55 
56  return b.create<arith::ConstantOp>(loc, attr);
57 }
58 
60  Type opType = operand.getType();
61  Type i64Ty = b.getI64Type();
62  if (auto shapedTy = dyn_cast<ShapedType>(opType))
63  i64Ty = shapedTy.clone(i64Ty);
64  Value fixedConvert = b.create<arith::FPToSIOp>(i64Ty, operand);
65  Value fpFixedConvert = b.create<arith::SIToFPOp>(opType, fixedConvert);
66  // The truncation does not preserve the sign when the truncated
67  // value is -0. So here the sign is copied again.
68  return b.create<math::CopySignOp>(fpFixedConvert, operand);
69 }
70 
71 // sinhf(float x) -> (exp(x) - exp(-x)) / 2
72 static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
73  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
74  Value operand = op.getOperand();
75  Type opType = operand.getType();
76 
77  Value exp = b.create<math::ExpOp>(operand);
78  Value neg = b.create<arith::NegFOp>(operand);
79  Value nexp = b.create<math::ExpOp>(neg);
80  Value sub = b.create<arith::SubFOp>(exp, nexp);
81  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
82  Value res = b.create<arith::MulFOp>(sub, half);
83  rewriter.replaceOp(op, res);
84  return success();
85 }
86 
87 // coshf(float x) -> (exp(x) + exp(-x)) / 2
88 static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
89  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
90  Value operand = op.getOperand();
91  Type opType = operand.getType();
92 
93  Value exp = b.create<math::ExpOp>(operand);
94  Value neg = b.create<arith::NegFOp>(operand);
95  Value nexp = b.create<math::ExpOp>(neg);
96  Value add = b.create<arith::AddFOp>(exp, nexp);
97  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
98  Value res = b.create<arith::MulFOp>(add, half);
99  rewriter.replaceOp(op, res);
100  return success();
101 }
102 
103 /// Expands tanh op into
104 /// 1-exp^{-2x} / 1+exp^{-2x}
105 /// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
106 /// We compute a "signs" value which is -1 if input is negative and +1 if input
107 /// is positive. Then multiply the input by this value, guaranteeing that the
108 /// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
109 /// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
110 /// result by `sign(x)` to retain sign of the real result.
111 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
112  auto floatType = op.getOperand().getType();
113  Location loc = op.getLoc();
114  Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
115  Value one = createFloatConst(loc, floatType, 1.0, rewriter);
116  Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
117 
118  // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
119  Value isNegative = rewriter.create<arith::CmpFOp>(
120  loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
121  Value isNegativeFloat =
122  rewriter.create<arith::UIToFPOp>(loc, floatType, isNegative);
123  Value isNegativeTimesNegTwo =
124  rewriter.create<arith::MulFOp>(loc, isNegativeFloat, negTwo);
125  Value sign = rewriter.create<arith::AddFOp>(loc, isNegativeTimesNegTwo, one);
126 
127  // Normalize input to positive value: y = sign(x) * x
128  Value positiveX = rewriter.create<arith::MulFOp>(loc, sign, op.getOperand());
129 
130  // Decompose on normalized input
131  Value negDoubledX = rewriter.create<arith::MulFOp>(loc, negTwo, positiveX);
132  Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
133  Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
134  Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
135  Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
136 
137  // Multiply result by sign(x) to retain signs from negative inputs
138  rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
139 
140  return success();
141 }
142 
143 // Converts math.tan to math.sin, math.cos, and arith.divf.
144 static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
145  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
146  Value operand = op.getOperand();
147  Type type = operand.getType();
148  Value sin = b.create<math::SinOp>(type, operand);
149  Value cos = b.create<math::CosOp>(type, operand);
150  Value div = b.create<arith::DivFOp>(type, sin, cos);
151  rewriter.replaceOp(op, div);
152  return success();
153 }
154 
155 // asinh(float x) -> log(x + sqrt(x**2 + 1))
156 static LogicalResult convertAsinhOp(math::AsinhOp op,
157  PatternRewriter &rewriter) {
158  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
159  Value operand = op.getOperand();
160  Type opType = operand.getType();
161 
162  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
163  Value fma = b.create<math::FmaOp>(operand, operand, one);
164  Value sqrt = b.create<math::SqrtOp>(fma);
165  Value add = b.create<arith::AddFOp>(operand, sqrt);
166  Value res = b.create<math::LogOp>(add);
167  rewriter.replaceOp(op, res);
168  return success();
169 }
170 
171 // acosh(float x) -> log(x + sqrt(x**2 - 1))
172 static LogicalResult convertAcoshOp(math::AcoshOp op,
173  PatternRewriter &rewriter) {
174  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
175  Value operand = op.getOperand();
176  Type opType = operand.getType();
177 
178  Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
179  Value fma = b.create<math::FmaOp>(operand, operand, negOne);
180  Value sqrt = b.create<math::SqrtOp>(fma);
181  Value add = b.create<arith::AddFOp>(operand, sqrt);
182  Value res = b.create<math::LogOp>(add);
183  rewriter.replaceOp(op, res);
184  return success();
185 }
186 
187 // atanh(float x) -> log((1 + x) / (1 - x)) / 2
188 static LogicalResult convertAtanhOp(math::AtanhOp op,
189  PatternRewriter &rewriter) {
190  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
191  Value operand = op.getOperand();
192  Type opType = operand.getType();
193 
194  Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
195  Value add = b.create<arith::AddFOp>(operand, one);
196  Value neg = b.create<arith::NegFOp>(operand);
197  Value sub = b.create<arith::AddFOp>(neg, one);
198  Value div = b.create<arith::DivFOp>(add, sub);
199  Value log = b.create<math::LogOp>(div);
200  Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
201  Value res = b.create<arith::MulFOp>(log, half);
202  rewriter.replaceOp(op, res);
203  return success();
204 }
205 
206 static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
207  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
208  Value operandA = op.getOperand(0);
209  Value operandB = op.getOperand(1);
210  Value operandC = op.getOperand(2);
211  Type type = op.getType();
212  Value mult = b.create<arith::MulFOp>(type, operandA, operandB);
213  Value add = b.create<arith::AddFOp>(type, mult, operandC);
214  rewriter.replaceOp(op, add);
215  return success();
216 }
217 
218 // Converts a floorf() function to the following:
219 // floorf(float x) ->
220 // y = (float)(int) x
221 // if (x < 0) then incr = -1 else incr = 0
222 // y = y + incr <= replace this op with the floorf op.
223 static LogicalResult convertFloorOp(math::FloorOp op,
224  PatternRewriter &rewriter) {
225  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
226  Value operand = op.getOperand();
227  Type opType = operand.getType();
228  Value fpFixedConvert = createTruncatedFPValue(operand, b);
229 
230  // Creating constants for later use.
231  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
232  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
233 
234  Value negCheck =
235  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
236  Value incrValue =
237  b.create<arith::SelectOp>(op->getLoc(), negCheck, negOne, zero);
238  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
239  rewriter.replaceOp(op, ret);
240  return success();
241 }
242 
243 // Converts a ceilf() function to the following:
244 // ceilf(float x) ->
245 // y = (float)(int) x
246 // if (x > y) then incr = 1 else incr = 0
247 // y = y + incr <= replace this op with the ceilf op.
248 static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
249  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
250  Value operand = op.getOperand();
251  Type opType = operand.getType();
252  Value fpFixedConvert = createTruncatedFPValue(operand, b);
253 
254  // Creating constants for later use.
255  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
256  Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
257 
258  Value gtCheck = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand,
259  fpFixedConvert);
260  Value incrValue = b.create<arith::SelectOp>(op->getLoc(), gtCheck, one, zero);
261 
262  Value ret = b.create<arith::AddFOp>(opType, fpFixedConvert, incrValue);
263  rewriter.replaceOp(op, ret);
264  return success();
265 }
266 
267 // Convert `math.fpowi` to a series of `arith.mulf` operations.
268 // If the power is negative, we divide one by the result.
269 // If both the base and power are zero, the result is 1.
270 // In the case of non constant power, we convert the operation to `math.powf`.
271 static LogicalResult convertFPowIOp(math::FPowIOp op,
272  PatternRewriter &rewriter) {
273  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
274  Value base = op.getOperand(0);
275  Value power = op.getOperand(1);
276  Type baseType = base.getType();
277 
278  auto convertFPowItoPowf = [&]() -> LogicalResult {
279  Value castPowerToFp =
280  rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
281  Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
282  castPowerToFp);
283  rewriter.replaceOp(op, res);
284  return success();
285  };
286 
287  Attribute cstAttr;
288  if (!matchPattern(power, m_Constant(&cstAttr)))
289  return convertFPowItoPowf();
290 
291  APInt value;
292  if (!matchPattern(cstAttr, m_ConstantInt(&value)))
293  return convertFPowItoPowf();
294 
295  int64_t powerInt = value.getSExtValue();
296  bool isNegative = powerInt < 0;
297  int64_t absPower = std::abs(powerInt);
298  Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
299  Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
300 
301  while (absPower > 0) {
302  if (absPower & 1)
303  res = b.create<arith::MulFOp>(baseType, base, res);
304  absPower >>= 1;
305  base = b.create<arith::MulFOp>(baseType, base, base);
306  }
307 
308  // Make sure not to introduce UB in case of negative power.
309  if (isNegative) {
310  auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
311  .getFloatSemantics();
312  Value zero =
313  createFloatConst(op->getLoc(), baseType,
314  APFloat::getZero(sem, /*Negative=*/false), rewriter);
315  Value negZero =
316  createFloatConst(op->getLoc(), baseType,
317  APFloat::getZero(sem, /*Negative=*/true), rewriter);
318  Value posInfinity =
319  createFloatConst(op->getLoc(), baseType,
320  APFloat::getInf(sem, /*Negative=*/false), rewriter);
321  Value negInfinity =
322  createFloatConst(op->getLoc(), baseType,
323  APFloat::getInf(sem, /*Negative=*/true), rewriter);
324  Value zeroEqCheck =
325  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, zero);
326  Value negZeroEqCheck =
327  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, res, negZero);
328  res = b.create<arith::DivFOp>(baseType, one, res);
329  res =
330  b.create<arith::SelectOp>(op->getLoc(), zeroEqCheck, posInfinity, res);
331  res = b.create<arith::SelectOp>(op->getLoc(), negZeroEqCheck, negInfinity,
332  res);
333  }
334 
335  rewriter.replaceOp(op, res);
336  return success();
337 }
338 
339 // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
340 static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
341  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
342  Value operandA = op.getOperand(0);
343  Value operandB = op.getOperand(1);
344  Type opType = operandA.getType();
345  Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
346  Value two = createFloatConst(op->getLoc(), opType, 2.00, rewriter);
347  Value negOne = createFloatConst(op->getLoc(), opType, -1.00, rewriter);
348  Value opASquared = b.create<arith::MulFOp>(opType, operandA, operandA);
349  Value opBHalf = b.create<arith::DivFOp>(opType, operandB, two);
350 
351  Value logA = b.create<math::LogOp>(opType, opASquared);
352  Value mult = b.create<arith::MulFOp>(opType, opBHalf, logA);
353  Value expResult = b.create<math::ExpOp>(opType, mult);
354  Value negExpResult = b.create<arith::MulFOp>(opType, expResult, negOne);
355  Value remainder = b.create<arith::RemFOp>(opType, operandB, two);
356  Value negCheck =
357  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operandA, zero);
358  Value oddPower =
359  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, remainder, zero);
360  Value oddAndNeg = b.create<arith::AndIOp>(op->getLoc(), oddPower, negCheck);
361 
362  Value res = b.create<arith::SelectOp>(op->getLoc(), oddAndNeg, negExpResult,
363  expResult);
364  rewriter.replaceOp(op, res);
365  return success();
366 }
367 
368 // exp2f(float x) -> exp(x * ln(2))
369 // Proof: Let's say 2^x = y
370 // ln(2^x) = ln(y)
371 // x * ln(2) = ln(y) => e ^(x*ln(2)) = y
372 static LogicalResult convertExp2fOp(math::Exp2Op op,
373  PatternRewriter &rewriter) {
374  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
375  Value operand = op.getOperand();
376  Type opType = operand.getType();
377  Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
378  Value mult = b.create<arith::MulFOp>(opType, operand, ln2);
379  Value exp = b.create<math::ExpOp>(op->getLoc(), mult);
380  rewriter.replaceOp(op, exp);
381  return success();
382 }
383 
384 static LogicalResult convertRoundOp(math::RoundOp op,
385  PatternRewriter &rewriter) {
386  Location loc = op.getLoc();
387  ImplicitLocOpBuilder b(loc, rewriter);
388  Value operand = op.getOperand();
389  Type opType = operand.getType();
390  Type opEType = getElementTypeOrSelf(opType);
391 
392  if (!opEType.isF32()) {
393  return rewriter.notifyMatchFailure(op, "not a round of f32.");
394  }
395 
396  Type i32Ty = b.getI32Type();
397  if (auto shapedTy = dyn_cast<ShapedType>(opType))
398  i32Ty = shapedTy.clone(i32Ty);
399 
400  Value half = createFloatConst(loc, opType, 0.5, b);
401  Value c23 = createIntConst(loc, i32Ty, 23, b);
402  Value c127 = createIntConst(loc, i32Ty, 127, b);
403  Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
404 
405  Value incrValue = b.create<math::CopySignOp>(half, operand);
406  Value add = b.create<arith::AddFOp>(opType, operand, incrValue);
407  Value fpFixedConvert = createTruncatedFPValue(add, b);
408 
409  // There are three cases where adding 0.5 to the value and truncating by
410  // converting to an i64 does not result in the correct behavior:
411  //
412  // 1. Special values: +-inf and +-nan
413  // Casting these special values to i64 has undefined behavior. To identify
414  // these values, we use the fact that these values are the only float
415  // values with the maximum possible biased exponent.
416  //
417  // 2. Large values: 2^23 <= |x| <= INT_64_MAX
418  // Adding 0.5 to a float larger than or equal to 2^23 results in precision
419  // errors that sometimes round the value up and sometimes round the value
420  // down. For example:
421  // 8388608.0 + 0.5 = 8388608.0
422  // 8388609.0 + 0.5 = 8388610.0
423  //
424  // 3. Very large values: |x| > INT_64_MAX
425  // Casting to i64 a value greater than the max i64 value will overflow the
426  // i64 leading to wrong outputs.
427  //
428  // All three cases satisfy the property `biasedExp >= 23`.
429  Value operandBitcast = b.create<arith::BitcastOp>(i32Ty, operand);
430  Value operandExp = b.create<arith::AndIOp>(
431  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
432  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
433  Value isSpecialValOrLargeVal =
434  b.create<arith::CmpIOp>(arith::CmpIPredicate::sge, operandBiasedExp, c23);
435 
436  Value result = b.create<arith::SelectOp>(isSpecialValOrLargeVal, operand,
437  fpFixedConvert);
438  rewriter.replaceOp(op, result);
439  return success();
440 }
441 
442 // Converts math.ctlz to scf and arith operations. This is done
443 // by performing a binary search on the bits.
444 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
445  PatternRewriter &rewriter) {
446  auto operand = op.getOperand();
447  auto operandTy = operand.getType();
448  auto eTy = getElementTypeOrSelf(operandTy);
449  Location loc = op.getLoc();
450 
451  int32_t bitwidth = eTy.getIntOrFloatBitWidth();
452  if (bitwidth > 64)
453  return failure();
454 
455  uint64_t allbits = -1;
456  if (bitwidth < 64) {
457  allbits = allbits >> (64 - bitwidth);
458  }
459 
460  Value x = operand;
461  Value count = createIntConst(loc, operandTy, 0, rewriter);
462  for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
463  auto half = bw / 2;
464  auto bits = createIntConst(loc, operandTy, half, rewriter);
465  auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
466 
467  Value pred =
468  rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule, x, mask);
469  Value add = rewriter.create<arith::AddIOp>(loc, count, bits);
470  Value shift = rewriter.create<arith::ShLIOp>(loc, x, bits);
471 
472  x = rewriter.create<arith::SelectOp>(loc, pred, shift, x);
473  count = rewriter.create<arith::SelectOp>(loc, pred, add, count);
474  }
475 
476  Value zero = createIntConst(loc, operandTy, 0, rewriter);
477  Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
478  operand, zero);
479 
480  Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
481  Value sel = rewriter.create<arith::SelectOp>(loc, pred, bwval, count);
482  rewriter.replaceOp(op, sel);
483  return success();
484 }
485 
486 // Convert `math.roundeven` into `math.round` + arith ops
487 static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
488  PatternRewriter &rewriter) {
489  Location loc = op.getLoc();
490  ImplicitLocOpBuilder b(loc, rewriter);
491  auto operand = op.getOperand();
492  Type operandTy = operand.getType();
493  Type resultTy = op.getType();
494  Type operandETy = getElementTypeOrSelf(operandTy);
495  Type resultETy = getElementTypeOrSelf(resultTy);
496 
497  if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
498  return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
499  }
500 
501  Type fTy = operandTy;
502  Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
503  if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
504  iTy = shapedTy.clone(iTy);
505  }
506 
507  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
508  // The width returned by getFPMantissaWidth includes the integer bit.
509  unsigned mantissaWidth =
510  llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
511  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
512 
513  // The names of the variables correspond to f32.
514  // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
515  // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
516  // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
517  Value c1Float = createFloatConst(loc, fTy, 1.0, b);
518  Value c0 = createIntConst(loc, iTy, 0, b);
519  Value c1 = createIntConst(loc, iTy, 1, b);
520  Value cNeg1 = createIntConst(loc, iTy, -1, b);
521  Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
522  Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
523  Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
524  Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
525  Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
526  Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
527 
528  Value operandBitcast = b.create<arith::BitcastOp>(iTy, operand);
529  Value round = b.create<math::RoundOp>(operand);
530  Value roundBitcast = b.create<arith::BitcastOp>(iTy, round);
531 
532  // Get biased exponents for operand and round(operand)
533  Value operandExp = b.create<arith::AndIOp>(
534  b.create<arith::ShRUIOp>(operandBitcast, c23), expMask);
535  Value operandBiasedExp = b.create<arith::SubIOp>(operandExp, c127);
536  Value roundExp = b.create<arith::AndIOp>(
537  b.create<arith::ShRUIOp>(roundBitcast, c23), expMask);
538  Value roundBiasedExp = b.create<arith::SubIOp>(roundExp, c127);
539 
540  auto safeShiftRight = [&](Value x, Value shift) -> Value {
541  // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
542  Value clampedShift = b.create<arith::MaxSIOp>(shift, c0);
543  clampedShift = b.create<arith::MinSIOp>(clampedShift, c31);
544  return b.create<arith::ShRUIOp>(x, clampedShift);
545  };
546 
547  auto maskMantissa = [&](Value mantissa,
548  Value mantissaMaskRightShift) -> Value {
549  Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
550  return b.create<arith::AndIOp>(mantissa, shiftedMantissaMask);
551  };
552 
553  // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
554  // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
555  // with `biasedExp > 23` (numbers where there is not enough precision to store
556  // decimals) are always even, and they satisfy the even condition trivially
557  // since the mantissa without all its bits is zero. The even condition
558  // is also true for +-0, since they have `biasedExp = -127` and the entire
559  // mantissa is zero. The case of +-1 has to be handled separately. Here
560  // we identify these values by noting that +-1 are the only whole numbers with
561  // `biasedExp == 0`.
562  //
563  // The special values +-inf and +-nan also satisfy the same property that
564  // whole non-unit even numbers satisfy. In particular, the special values have
565  // `biasedExp > 23`, so they get treated as large numbers with no room for
566  // decimals, which are always even.
567  Value roundBiasedExpEq0 =
568  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, roundBiasedExp, c0);
569  Value roundBiasedExpMinus1 = b.create<arith::SubIOp>(roundBiasedExp, c1);
570  Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
571  Value roundIsNotEvenOrSpecialVal = b.create<arith::CmpIOp>(
572  arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
573  roundIsNotEvenOrSpecialVal =
574  b.create<arith::OrIOp>(roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
575 
576  // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
577  // integers if the bit at index `biasedExp` starting from the left in the
578  // mantissa is 1 and all the bits to the right are zero. Values with
579  // `biasedExp >= 23` don't have decimals, so they are never halfway. The
580  // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
581  // so these are handled separately. In particular, if `biasedExp == -1`, the
582  // value is halfway if the entire mantissa is zero.
583  Value operandBiasedExpEqNeg1 = b.create<arith::CmpIOp>(
584  arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
585  Value expectedOperandMaskedMantissa = b.create<arith::SelectOp>(
586  operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
587  Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
588  Value operandIsHalfway =
589  b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, operandMaskedMantissa,
590  expectedOperandMaskedMantissa);
591  // Ensure `biasedExp` is in the valid range for half values.
592  Value operandBiasedExpGeNeg1 = b.create<arith::CmpIOp>(
593  arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
594  Value operandBiasedExpLt23 =
595  b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, operandBiasedExp, c23);
596  operandIsHalfway =
597  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpLt23);
598  operandIsHalfway =
599  b.create<arith::AndIOp>(operandIsHalfway, operandBiasedExpGeNeg1);
600 
601  // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
602  // case where `round` rounded in the opposite direction of `roundeven`.
603  Value sign = b.create<math::CopySignOp>(c1Float, operand);
604  Value roundShifted = b.create<arith::SubFOp>(round, sign);
605  // If the rounded value is even or a special value, we default to the behavior
606  // of `math.round`.
607  Value needsShift =
608  b.create<arith::AndIOp>(roundIsNotEvenOrSpecialVal, operandIsHalfway);
609  Value result = b.create<arith::SelectOp>(needsShift, roundShifted, round);
610  // The `x - sign` adjustment does not preserve the sign when we are adjusting
611  // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
612  // rounded to -0.0.
613  result = b.create<math::CopySignOp>(result, operand);
614  rewriter.replaceOp(op, result);
615  return success();
616 }
617 
618 // Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
619 static LogicalResult convertRsqrtOp(math::RsqrtOp op,
620  PatternRewriter &rewriter) {
621 
622  auto operand = op.getOperand();
623  auto operandTy = operand.getType();
624  auto eTy = getElementTypeOrSelf(operandTy);
625  if (!isa<FloatType>(eTy))
626  return failure();
627 
628  Location loc = op->getLoc();
629  auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
630  auto sqrtOp = rewriter.create<math::SqrtOp>(loc, operand);
631  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
632  return success();
633 }
634 
636  patterns.add(convertCtlzOp);
637 }
638 
640  patterns.add(convertSinhOp);
641 }
642 
644  patterns.add(convertCoshOp);
645 }
646 
648  patterns.add(convertTanOp);
649 }
650 
652  patterns.add(convertTanhOp);
653 }
654 
656  patterns.add(convertAsinhOp);
657 }
658 
660  patterns.add(convertAcoshOp);
661 }
662 
664  patterns.add(convertAtanhOp);
665 }
666 
668  patterns.add(convertFmaFOp);
669 }
670 
672  patterns.add(convertCeilOp);
673 }
674 
676  patterns.add(convertExp2fOp);
677 }
678 
680  patterns.add(convertPowfOp);
681 }
682 
684  patterns.add(convertFPowIOp);
685 }
686 
688  patterns.add(convertRoundOp);
689 }
690 
692  patterns.add(convertFloorOp);
693 }
694 
696  patterns.add(convertRoundEvenOp);
697 }
698 
700  patterns.add(convertRsqrtOp);
701 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertFloorOp(math::FloorOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into 1-exp^{-2x} / 1+exp^{-2x} To avoid overflow we exploit the reflection symmetry t...
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerType getI64Type()
Definition: Builders.cpp:89
IntegerType getI32Type()
Definition: Builders.cpp:87
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Value getOperand(unsigned idx)
Definition: Operation.h:345
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:52
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:135
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
void populateExpandSinhPattern(RewritePatternSet &patterns)
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:438
void populateExpandRsqrtPattern(RewritePatternSet &patterns)
void populateExpandTanhPattern(RewritePatternSet &patterns)
void populateExpandFmaFPattern(RewritePatternSet &patterns)
void populateExpandAcoshPattern(RewritePatternSet &patterns)
void populateExpandFPowIPattern(RewritePatternSet &patterns)
void populateExpandPowFPattern(RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateExpandTanPattern(RewritePatternSet &patterns)
void populateExpandCoshPattern(RewritePatternSet &patterns)
void populateExpandRoundFPattern(RewritePatternSet &patterns)
void populateExpandExp2FPattern(RewritePatternSet &patterns)
void populateExpandCeilFPattern(RewritePatternSet &patterns)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
void populateExpandCtlzPattern(RewritePatternSet &patterns)
void populateExpandAsinhPattern(RewritePatternSet &patterns)
void populateExpandRoundEvenPattern(RewritePatternSet &patterns)
void populateExpandAtanhPattern(RewritePatternSet &patterns)
void populateExpandFloorFPattern(RewritePatternSet &patterns)