MLIR  20.0.0git
PolynomialApproximation.cpp
Go to the documentation of this file.
1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include <climits>
#include <cmath>
#include <cstddef>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Approximation.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"
37 
38 using namespace mlir;
39 using namespace mlir::math;
40 using namespace mlir::vector;
41 
42 // Helper to encapsulate a vector's shape (including scalable dims).
43 struct VectorShape {
46 
47  bool empty() const { return sizes.empty(); }
48 };
49 
50 // Returns vector shape if the type is a vector. Returns an empty shape if it is
51 // not a vector.
53  auto vectorType = dyn_cast<VectorType>(type);
54  return vectorType
55  ? VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
56  : VectorShape{};
57 }
58 
59 static VectorShape vectorShape(Value value) {
60  return vectorShape(value.getType());
61 }
62 
63 //----------------------------------------------------------------------------//
64 // Broadcast scalar types and values into vector types and values.
65 //----------------------------------------------------------------------------//
66 
67 // Broadcasts scalar type into vector type (iff shape is non-scalar).
68 static Type broadcast(Type type, VectorShape shape) {
69  assert(!isa<VectorType>(type) && "must be scalar type");
70  return !shape.empty()
71  ? VectorType::get(shape.sizes, type, shape.scalableFlags)
72  : type;
73 }
74 
75 // Broadcasts scalar value into vector (iff shape is non-scalar).
76 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
77  VectorShape shape) {
78  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
79  auto type = broadcast(value.getType(), shape);
80  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
81 }
82 
83 //----------------------------------------------------------------------------//
84 // Helper function to handle n-D vectors with 1-D operations.
85 //----------------------------------------------------------------------------//
86 
87 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
88 // and calls the compute function with 1-D vector operands. Stitches back all
89 // results into the original n-D vector result.
90 //
91 // Examples: vectorWidth = 8
92 // - vector<4x8xf32> unrolled 4 times
93 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
94 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
95 //
96 // Some math approximations rely on ISA-specific operations that only accept
97 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
98 //
99 // It is the caller's responsibility to verify that the inner dimension is
100 // divisible by the vectorWidth, and that all operands have the same vector
101 // shape.
102 static Value
104  ValueRange operands, int64_t vectorWidth,
105  llvm::function_ref<Value(ValueRange)> compute) {
106  assert(!operands.empty() && "operands must be not empty");
107  assert(vectorWidth > 0 && "vector width must be larger than 0");
108 
109  VectorType inputType = cast<VectorType>(operands[0].getType());
110  ArrayRef<int64_t> inputShape = inputType.getShape();
111 
112  // If input shape matches target vector width, we can just call the
113  // user-provided compute function with the operands.
114  if (inputShape == llvm::ArrayRef(vectorWidth))
115  return compute(operands);
116 
117  // Check if the inner dimension has to be expanded, or we can directly iterate
118  // over the outer dimensions of the vector.
119  int64_t innerDim = inputShape.back();
120  int64_t expansionDim = innerDim / vectorWidth;
121  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
122 
123  // Maybe expand operands to the higher rank vector shape that we'll use to
124  // iterate over and extract one dimensional vectors.
125  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
126  SmallVector<Value> expandedOperands(operands);
127 
128  if (expansionDim > 1) {
129  // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
130  expandedShape.insert(expandedShape.end() - 1, expansionDim);
131  expandedShape.back() = vectorWidth;
132 
133  for (unsigned i = 0; i < operands.size(); ++i) {
134  auto operand = operands[i];
135  auto eltType = cast<VectorType>(operand.getType()).getElementType();
136  auto expandedType = VectorType::get(expandedShape, eltType);
137  expandedOperands[i] =
138  builder.create<vector::ShapeCastOp>(expandedType, operand);
139  }
140  }
141 
142  // Iterate over all outer dimensions of the compute shape vector type.
143  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
144  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
145  auto strides = computeStrides(iterationDims);
146 
147  // Compute results for each one dimensional vector.
148  SmallVector<Value> results(maxIndex);
149 
150  for (int64_t i = 0; i < maxIndex; ++i) {
151  auto offsets = delinearize(i, strides);
152 
153  SmallVector<Value> extracted(expandedOperands.size());
154  for (const auto &tuple : llvm::enumerate(expandedOperands))
155  extracted[tuple.index()] =
156  builder.create<vector::ExtractOp>(tuple.value(), offsets);
157 
158  results[i] = compute(extracted);
159  }
160 
161  // Stitch results together into one large vector.
162  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
163  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
164  Value result = builder.create<arith::ConstantOp>(
165  resultExpandedType, builder.getZeroAttr(resultExpandedType));
166 
167  for (int64_t i = 0; i < maxIndex; ++i)
168  result = builder.create<vector::InsertOp>(results[i], result,
169  delinearize(i, strides));
170 
171  // Reshape back to the original vector shape.
172  return builder.create<vector::ShapeCastOp>(
173  VectorType::get(inputShape, resultEltType), result);
174 }
175 
176 //----------------------------------------------------------------------------//
177 // Helper functions to create constants.
178 //----------------------------------------------------------------------------//
179 
180 static Value floatCst(ImplicitLocOpBuilder &builder, float value,
181  Type elementType) {
182  assert((elementType.isF16() || elementType.isF32()) &&
183  "x must be f16 or f32 type.");
184  return builder.create<arith::ConstantOp>(
185  builder.getFloatAttr(elementType, value));
186 }
187 
188 static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
189  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
190 }
191 
192 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
193  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
194 }
195 
196 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
197  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
198  return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
199 }
200 
201 //----------------------------------------------------------------------------//
202 // Helper functions to build math functions approximations.
203 //----------------------------------------------------------------------------//
204 
205 // Return the minimum of the two values or NaN if value is NaN
206 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
207  return builder.create<arith::SelectOp>(
208  builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
209  value, bound);
210 }
211 
212 // Return the maximum of the two values or NaN if value is NaN
213 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
214  return builder.create<arith::SelectOp>(
215  builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
216  value, bound);
217 }
218 
219 // Return the clamped value or NaN if value is NaN
220 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
221  Value upperBound) {
222  return max(builder, min(builder, value, upperBound), lowerBound);
223 }
224 
225 // Decomposes given floating point value `arg` into a normalized fraction and
226 // an integral power of two (see std::frexp). Returned values have float type.
227 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228  bool isPositive = false) {
229  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230  VectorShape shape = vectorShape(arg);
231 
232  auto bcast = [&](Value value) -> Value {
233  return broadcast(builder, value, shape);
234  };
235 
236  auto i32 = builder.getIntegerType(32);
237  auto i32Vec = broadcast(i32, shape);
238  auto f32Vec = broadcast(builder.getF32Type(), shape);
239 
240  Value cst126f = f32Cst(builder, 126.0f);
241  Value cstHalf = f32Cst(builder, 0.5f);
242  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
243 
244  // Bitcast to i32 for bitwise operations.
245  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
246  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
247  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
248 
249  // Compute normalized fraction.
250  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
252  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
253 
254  // Compute exponent.
255  Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
256  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
257  builder.create<arith::BitcastOp>(i32Vec, arg0),
258  bcast(i32Cst(builder, 23)));
259  Value biasedExponent =
260  builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
261  Value exponent =
262  builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
263 
264  return {normalizedFraction, exponent};
265 }
266 
267 // Computes exp2 for an i32 argument.
268 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270  VectorShape shape = vectorShape(arg);
271 
272  auto bcast = [&](Value value) -> Value {
273  return broadcast(builder, value, shape);
274  };
275 
276  auto f32Vec = broadcast(builder.getF32Type(), shape);
277  // The exponent of f32 located at 23-bit.
278  auto exponetBitLocation = bcast(i32Cst(builder, 23));
279  // Set the exponent bias to zero.
280  auto bias = bcast(i32Cst(builder, 127));
281 
282  Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
283  Value exp2ValueInt =
284  builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285  Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
286 
287  return exp2ValueF32;
288 }
289 
290 namespace {
291 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
292  llvm::ArrayRef<Value> coeffs, Value x) {
293  Type elementType = getElementTypeOrSelf(x);
294  assert((elementType.isF32() || elementType.isF16()) &&
295  "x must be f32 or f16 type");
296  VectorShape shape = vectorShape(x);
297 
298  if (coeffs.empty())
299  return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
300 
301  if (coeffs.size() == 1)
302  return coeffs[0];
303 
304  Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305  coeffs[coeffs.size() - 2]);
306  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307  res = builder.create<math::FmaOp>(x, res, coeffs[i]);
308  }
309  return res;
310 }
311 } // namespace
312 
313 //----------------------------------------------------------------------------//
314 // Helper function/pattern to insert casts for reusing F32 bit expansion.
315 //----------------------------------------------------------------------------//
316 
317 template <typename T>
318 LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
319  // Conservatively only allow where the operand and result types are exactly 1.
320  Type origType = op->getResultTypes().front();
321  for (Type t : llvm::drop_begin(op->getResultTypes()))
322  if (origType != t)
323  return rewriter.notifyMatchFailure(op, "required all types to match");
324  for (Type t : op->getOperandTypes())
325  if (origType != t)
326  return rewriter.notifyMatchFailure(op, "required all types to match");
327 
328  // Skip if already F32 or larger than 32 bits.
329  if (getElementTypeOrSelf(origType).isF32() ||
330  getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
331  return failure();
332 
333  // Create F32 equivalent type.
334  Type newType;
335  if (auto shaped = dyn_cast<ShapedType>(origType)) {
336  newType = shaped.clone(rewriter.getF32Type());
337  } else if (isa<FloatType>(origType)) {
338  newType = rewriter.getF32Type();
339  } else {
340  return rewriter.notifyMatchFailure(op,
341  "unable to find F32 equivalent type");
342  }
343 
344  Location loc = op->getLoc();
345  SmallVector<Value> operands;
346  for (auto operand : op->getOperands())
347  operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
348  auto result =
349  rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
350  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
351  return success();
352 }
353 
354 namespace {
355 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
356 // op.
357 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
358 // all in lower precision. Currently this is only fallback support and performs
359 // simplistic casting.
360 template <typename T>
361 struct ReuseF32Expansion : public OpRewritePattern<T> {
362 public:
364  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
365  static_assert(
366  T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367  "requires same operands and result types");
368  return insertCasts<T>(op, rewriter);
369  }
370 };
371 } // namespace
372 
373 //----------------------------------------------------------------------------//
374 // AtanOp approximation.
375 //----------------------------------------------------------------------------//
376 
377 namespace {
378 struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
379 public:
381 
382  LogicalResult matchAndRewrite(math::AtanOp op,
383  PatternRewriter &rewriter) const final;
384 };
385 } // namespace
386 
387 LogicalResult
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
389  PatternRewriter &rewriter) const {
390  auto operand = op.getOperand();
391  if (!getElementTypeOrSelf(operand).isF32())
392  return rewriter.notifyMatchFailure(op, "unsupported operand type");
393 
394  VectorShape shape = vectorShape(op.getOperand());
395 
396  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397  Value abs = builder.create<math::AbsFOp>(operand);
398 
399  auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
400 
401  // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
402  auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
403  Value cmp2 =
404  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
405  Value addone = builder.create<arith::AddFOp>(abs, one);
406  Value subone = builder.create<arith::SubFOp>(abs, one);
407  Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
408  Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
409 
410  auto bcast = [&](Value value) -> Value {
411  return broadcast(builder, value, shape);
412  };
413 
414  // Break into the <= 0.66 or > 2.41 we do x or 1/x:
415  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
416  Value cmp1 =
417  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
418  xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
419  xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
420 
421  Value x = builder.create<arith::DivFOp>(xnum, xden);
422  Value xx = builder.create<arith::MulFOp>(x, x);
423 
424  // Perform the Taylor series approximation for atan over the range
425  // [0.0, 0.66].
426  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
427  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
428  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
429  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
430  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
431  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
432  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
433  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
434  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
435  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
436 
437  // Apply the polynomial approximation for the numerator:
438  Value n = p0;
439  n = builder.create<math::FmaOp>(xx, n, p1);
440  n = builder.create<math::FmaOp>(xx, n, p2);
441  n = builder.create<math::FmaOp>(xx, n, p3);
442  n = builder.create<math::FmaOp>(xx, n, p4);
443  n = builder.create<arith::MulFOp>(n, xx);
444 
445  // Apply the polynomial approximation for the denominator:
446  Value d = q0;
447  d = builder.create<math::FmaOp>(xx, d, q1);
448  d = builder.create<math::FmaOp>(xx, d, q2);
449  d = builder.create<math::FmaOp>(xx, d, q3);
450  d = builder.create<math::FmaOp>(xx, d, q4);
451 
452  // Compute approximation of theta:
453  Value ans0 = builder.create<arith::DivFOp>(n, d);
454  ans0 = builder.create<math::FmaOp>(ans0, x, x);
455 
456  // Correct for the input mapping's angles:
457  Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
458  Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
459  Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
460 
461  Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
462  Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
463  ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
464 
465  // Correct for signing of the input.
466  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
467  return success();
468 }
469 
//----------------------------------------------------------------------------//
// Atan2Op approximation.
//----------------------------------------------------------------------------//
473 
474 namespace {
475 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
476 public:
478 
479  LogicalResult matchAndRewrite(math::Atan2Op op,
480  PatternRewriter &rewriter) const final;
481 };
482 } // namespace
483 
484 LogicalResult
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486  PatternRewriter &rewriter) const {
487  auto y = op.getOperand(0);
488  auto x = op.getOperand(1);
489  if (!getElementTypeOrSelf(x).isF32())
490  return rewriter.notifyMatchFailure(op, "unsupported operand type");
491 
492  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493  VectorShape shape = vectorShape(op.getResult());
494 
495  // Compute atan in the valid range.
496  auto div = builder.create<arith::DivFOp>(y, x);
497  auto atan = builder.create<math::AtanOp>(div);
498 
499  // Determine what the atan would be for a 180 degree rotation.
500  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
501  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
502  auto addPi = builder.create<arith::AddFOp>(atan, pi);
503  auto subPi = builder.create<arith::SubFOp>(atan, pi);
504  auto atanGt =
505  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
507 
508  // Determine whether to directly use atan or use the 180 degree flip
509  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
511 
512  // Handle x = 0, y > 0
513  Value xZero =
514  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
517  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
518  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
519 
520  // Handle x = 0, y < 0
521  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
523  auto negativeHalfPiPi =
524  broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
525  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
526  result);
527 
528  // Handle x = 0, y = 0;
529  Value yZero =
530  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
532  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
533  result = builder.create<arith::SelectOp>(isNan, cstNan, result);
534 
535  rewriter.replaceOp(op, result);
536  return success();
537 }
538 
539 //----------------------------------------------------------------------------//
540 // TanhOp approximation.
541 //----------------------------------------------------------------------------//
542 
543 namespace {
544 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
545 public:
547 
548  LogicalResult matchAndRewrite(math::TanhOp op,
549  PatternRewriter &rewriter) const final;
550 };
551 } // namespace
552 
553 LogicalResult
554 TanhApproximation::matchAndRewrite(math::TanhOp op,
555  PatternRewriter &rewriter) const {
557  return rewriter.notifyMatchFailure(op, "unsupported operand type");
558 
559  VectorShape shape = vectorShape(op.getOperand());
560 
561  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
562  auto bcast = [&](Value value) -> Value {
563  return broadcast(builder, value, shape);
564  };
565 
566  // Clamp operand into [plusClamp, minusClamp] range.
567  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
568  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
569  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
570 
571  // Mask for tiny values that are approximated with `operand`.
572  Value tiny = bcast(f32Cst(builder, 0.0004f));
573  Value tinyMask = builder.create<arith::CmpFOp>(
574  arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
575  tiny);
576 
577  // The monomial coefficients of the numerator polynomial (odd).
578  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
579  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
580  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
581  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
582  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
583  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
584  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
585 
586  // The monomial coefficients of the denominator polynomial (even).
587  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
588  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
589  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
590  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
591 
592  // Since the polynomials are odd/even, we need x^2.
593  Value x2 = builder.create<arith::MulFOp>(x, x);
594 
595  // Evaluate the numerator polynomial p.
596  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
597  p = builder.create<math::FmaOp>(x2, p, alpha9);
598  p = builder.create<math::FmaOp>(x2, p, alpha7);
599  p = builder.create<math::FmaOp>(x2, p, alpha5);
600  p = builder.create<math::FmaOp>(x2, p, alpha3);
601  p = builder.create<math::FmaOp>(x2, p, alpha1);
602  p = builder.create<arith::MulFOp>(x, p);
603 
604  // Evaluate the denominator polynomial q.
605  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
606  q = builder.create<math::FmaOp>(x2, q, beta2);
607  q = builder.create<math::FmaOp>(x2, q, beta0);
608 
609  // Divide the numerator by the denominator.
610  Value res = builder.create<arith::SelectOp>(
611  tinyMask, x, builder.create<arith::DivFOp>(p, q));
612 
613  rewriter.replaceOp(op, res);
614 
615  return success();
616 }
617 
618 #define LN2_VALUE \
619  0.693147180559945309417232121458176568075500134360255254120680009493393621L
620 #define LOG2E_VALUE \
621  1.442695040888963407359924681001892137426645954152985934135449406931109219L
622 
623 //----------------------------------------------------------------------------//
624 // LogOp and Log2Op approximation.
625 //----------------------------------------------------------------------------//
626 
627 namespace {
628 template <typename Op>
629 struct LogApproximationBase : public OpRewritePattern<Op> {
631 
632  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
633  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
634  bool base2) const;
635 };
636 } // namespace
637 
638 // This approximation comes from Julien Pommier's SSE math library.
639 // Link: http://gruntthepeon.free.fr/ssemath
640 template <typename Op>
641 LogicalResult
642 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
643  bool base2) const {
645  return rewriter.notifyMatchFailure(op, "unsupported operand type");
646 
647  VectorShape shape = vectorShape(op.getOperand());
648 
649  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
650  auto bcast = [&](Value value) -> Value {
651  return broadcast(builder, value, shape);
652  };
653 
654  Value cstZero = bcast(f32Cst(builder, 0.0f));
655  Value cstOne = bcast(f32Cst(builder, 1.0f));
656  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
657 
658  // The smallest non denormalized float number.
659  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
660  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
661  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
662  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
663 
664  // Polynomial coefficients.
665  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
666  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
667  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
668  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
669  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
670  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
671  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
672  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
673  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
674  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
675 
676  Value x = op.getOperand();
677 
678  // Truncate input values to the minimum positive normal.
679  x = max(builder, x, cstMinNormPos);
680 
681  // Extract significant in the range [0.5,1) and exponent.
682  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
683  x = pair.first;
684  Value e = pair.second;
685 
686  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
687  // by -1.0. The values are then centered around 0, which improves the
688  // stability of the polynomial evaluation:
689  //
690  // if( x < SQRTHF ) {
691  // e -= 1;
692  // x = x + x - 1.0;
693  // } else { x = x - 1.0; }
694  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
695  cstCephesSQRTHF);
696  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
697 
698  x = builder.create<arith::SubFOp>(x, cstOne);
699  e = builder.create<arith::SubFOp>(
700  e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
701  x = builder.create<arith::AddFOp>(x, tmp);
702 
703  Value x2 = builder.create<arith::MulFOp>(x, x);
704  Value x3 = builder.create<arith::MulFOp>(x2, x);
705 
706  // Evaluate the polynomial approximant of degree 8 in three parts.
707  Value y0, y1, y2;
708  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
712  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
713  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
714  y0 = builder.create<math::FmaOp>(y0, x3, y1);
715  y0 = builder.create<math::FmaOp>(y0, x3, y2);
716  y0 = builder.create<arith::MulFOp>(y0, x3);
717 
718  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
719  x = builder.create<arith::AddFOp>(x, y0);
720 
721  if (base2) {
722  Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
723  x = builder.create<math::FmaOp>(x, cstLog2e, e);
724  } else {
725  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
726  x = builder.create<math::FmaOp>(e, cstLn2, x);
727  }
728 
729  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
730  op.getOperand(), cstZero);
731  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
732  op.getOperand(), cstZero);
733  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
734  op.getOperand(), cstPosInf);
735 
736  // Filter out invalid values:
737  // • x == 0 -> -INF
738  // • x < 0 -> NAN
739  // • x == +INF -> +INF
740  Value aproximation = builder.create<arith::SelectOp>(
741  zeroMask, cstMinusInf,
742  builder.create<arith::SelectOp>(
743  invalidMask, cstNan,
744  builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
745 
746  rewriter.replaceOp(op, aproximation);
747 
748  return success();
749 }
750 
751 namespace {
752 struct LogApproximation : public LogApproximationBase<math::LogOp> {
753  using LogApproximationBase::LogApproximationBase;
754 
755  LogicalResult matchAndRewrite(math::LogOp op,
756  PatternRewriter &rewriter) const final {
757  return logMatchAndRewrite(op, rewriter, /*base2=*/false);
758  }
759 };
760 } // namespace
761 
762 namespace {
763 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
764  using LogApproximationBase::LogApproximationBase;
765 
766  LogicalResult matchAndRewrite(math::Log2Op op,
767  PatternRewriter &rewriter) const final {
768  return logMatchAndRewrite(op, rewriter, /*base2=*/true);
769  }
770 };
771 } // namespace
772 
773 //----------------------------------------------------------------------------//
774 // Log1p approximation.
775 //----------------------------------------------------------------------------//
776 
777 namespace {
778 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
779 public:
781 
782  LogicalResult matchAndRewrite(math::Log1pOp op,
783  PatternRewriter &rewriter) const final;
784 };
785 } // namespace
786 
787 // Approximate log(1+x).
788 LogicalResult
789 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
790  PatternRewriter &rewriter) const {
792  return rewriter.notifyMatchFailure(op, "unsupported operand type");
793 
794  VectorShape shape = vectorShape(op.getOperand());
795 
796  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
797  auto bcast = [&](Value value) -> Value {
798  return broadcast(builder, value, shape);
799  };
800 
801  // Approximate log(1+x) using the following, due to W. Kahan:
802  // u = x + 1.0;
803  // if (u == 1.0 || u == inf) return x;
804  // return x * log(u) / (u - 1.0);
805  // ^^^^^^^^^^^^^^^^^^^^^^
806  // "logLarge" below.
807  Value cstOne = bcast(f32Cst(builder, 1.0f));
808  Value x = op.getOperand();
809  Value u = builder.create<arith::AddFOp>(x, cstOne);
810  Value uSmall =
811  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
812  Value logU = builder.create<math::LogOp>(u);
813  Value uInf =
814  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
815  Value logLarge = builder.create<arith::MulFOp>(
816  x, builder.create<arith::DivFOp>(
817  logU, builder.create<arith::SubFOp>(u, cstOne)));
818  Value approximation = builder.create<arith::SelectOp>(
819  builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
820  rewriter.replaceOp(op, approximation);
821  return success();
822 }
823 
824 //----------------------------------------------------------------------------//
825 // Asin approximation.
826 //----------------------------------------------------------------------------//
827 
828 // Approximates asin(x).
829 // This approximation is based on the following stackoverflow post:
830 // https://stackoverflow.com/a/42683455
831 namespace {
832 struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
833 public:
835 
836  LogicalResult matchAndRewrite(math::AsinOp op,
837  PatternRewriter &rewriter) const final;
838 };
839 } // namespace
// Expands math.asin into a polynomial in x^2 (coefficients from the
// stackoverflow reference cited above), evaluated with math.fma chains.
// Supports f32/f16 scalars and vectors thereof.
LogicalResult
AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 type is supported.");
  VectorShape shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat scalar constants to the operand's vector shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // s = x^2, q = x^4. The polynomial is evaluated as two interleaved Horner
  // chains in q (accumulators r and t), which halves the dependency depth
  // compared to a single chain in s.
  Value s = mul(operand, operand);
  Value q = mul(s, s);
  Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
  Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));

  r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
  r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
  t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
  // Merge the two chains: r = r*s + t, then fold in the leading 1/6 term.
  r = fma(r, s, t);
  r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
  // Final result: x + x^3 * poly(x^2)  (t = x^3 here).
  t = mul(operand, s);
  r = fma(r, t, operand);

  rewriter.replaceOp(op, r);
  return success();
}
887 
888 //----------------------------------------------------------------------------//
889 // Acos approximation.
890 //----------------------------------------------------------------------------//
891 
892 // Approximates acos(x).
893 // This approximation is based on the following stackoverflow post:
894 // https://stackoverflow.com/a/42683455
895 namespace {
896 struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
897 public:
899 
900  LogicalResult matchAndRewrite(math::AcosOp op,
901  PatternRewriter &rewriter) const final;
902 };
903 } // namespace
// Expands math.acos in terms of math.asin, using the range-reduction scheme
// from the stackoverflow reference cited above. Supports f32/f16 scalars and
// vectors thereof. Note: relies on the AsinOp expansion being applied
// afterwards for the generated math.asin ops.
LogicalResult
AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
                                             PatternRewriter &rewriter) const {
  Value operand = op.getOperand();
  Type elementType = getElementTypeOrSelf(operand);

  if (!(elementType.isF32() || elementType.isF16()))
    return rewriter.notifyMatchFailure(op,
                                       "only f32 and f16 type is supported.");
  VectorShape shape = vectorShape(operand);

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat scalar constants to the operand's vector shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  auto fma = [&](Value a, Value b, Value c) -> Value {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // r = -|x|: fold both signs onto the non-positive half-range; the sign of
  // the original operand is handled by the [-1, 0) fix-up at the end.
  Value negOperand = builder.create<arith::NegFOp>(operand);
  Value zero = bcast(floatCst(builder, 0.0, elementType));
  Value half = bcast(floatCst(builder, 0.5, elementType));
  Value negOne = bcast(floatCst(builder, -1.0, elementType));
  Value selR =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
  Value r = builder.create<arith::SelectOp>(selR, negOperand, operand);
  Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
  Value firstPred =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);

  // For r > -0.5625: acos(|x|) = pi/2 + asin(r). The two constants multiply
  // to ~pi/2; presumably split into factors for precision — see the cited
  // reference.
  Value trueVal =
      fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
          bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
          builder.create<math::AsinOp>(r));

  // Otherwise use the half-angle identity: acos(|x|) = 2*asin(sqrt((1+r)/2)).
  Value falseVal = builder.create<math::SqrtOp>(fma(half, r, half));
  falseVal = builder.create<math::AsinOp>(falseVal);
  falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);

  r = builder.create<arith::SelectOp>(firstPred, trueVal, falseVal);

  // Check whether the operand lies in between [-1.0, 0.0).
  Value greaterThanNegOne =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);

  Value lessThanZero =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);

  Value betweenNegOneZero =
      builder.create<arith::AndIOp>(greaterThanNegOne, lessThanZero);

  // Negative inputs: acos(x) = pi - acos(|x|). The two constants multiply to
  // ~pi (same factored-constant trick as above).
  trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
                bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
                builder.create<arith::NegFOp>(r));

  Value finalVal =
      builder.create<arith::SelectOp>(betweenNegOneZero, trueVal, r);

  rewriter.replaceOp(op, finalVal);
  return success();
}
970 
971 //----------------------------------------------------------------------------//
972 // Erf approximation.
973 //----------------------------------------------------------------------------//
974 
975 // Approximates erf(x) with
976 // a - P(x)/Q(x)
977 // where P and Q are polynomials of degree 4.
978 // Different coefficients are chosen based on the value of x.
979 // The approximation error is ~2.5e-07.
980 // Boost's minimax tool that utilizes the Remez method was used to find the
981 // coefficients.
982 LogicalResult
984  PatternRewriter &rewriter) const {
985  Value operand = op.getOperand();
986  Type elementType = getElementTypeOrSelf(operand);
987 
988  if (!(elementType.isF32() || elementType.isF16()))
989  return rewriter.notifyMatchFailure(op,
990  "only f32 and f16 type is supported.");
991  VectorShape shape = vectorShape(operand);
992 
993  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
994  auto bcast = [&](Value value) -> Value {
995  return broadcast(builder, value, shape);
996  };
997 
998  const int intervalsCount = 3;
999  const int polyDegree = 4;
1000 
1001  Value zero = bcast(floatCst(builder, 0, elementType));
1002  Value one = bcast(floatCst(builder, 1, elementType));
1003  Value pp[intervalsCount][polyDegree + 1];
1004  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1005  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
1006  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
1007  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
1008  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
1009  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1010  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
1011  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
1012  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
1013  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
1014  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
1015  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
1016  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
1017  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
1018  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
1019 
1020  Value qq[intervalsCount][polyDegree + 1];
1021  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
1022  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
1023  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
1024  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
1025  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
1026  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1027  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
1028  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
1029  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
1030  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
1031  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1032  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
1033  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
1034  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
1035  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
1036 
1037  Value offsets[intervalsCount];
1038  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
1039  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
1040  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
1041 
1042  Value bounds[intervalsCount];
1043  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
1044  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
1045  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
1046 
1047  Value isNegativeArg =
1048  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1049  Value negArg = builder.create<arith::NegFOp>(operand);
1050  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
1051 
1052  Value offset = offsets[0];
1053  Value p[polyDegree + 1];
1054  Value q[polyDegree + 1];
1055  for (int i = 0; i <= polyDegree; ++i) {
1056  p[i] = pp[0][i];
1057  q[i] = qq[0][i];
1058  }
1059 
1060  // TODO: maybe use vector stacking to reduce the number of selects.
1061  Value isLessThanBound[intervalsCount];
1062  for (int j = 0; j < intervalsCount - 1; ++j) {
1063  isLessThanBound[j] =
1064  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
1065  for (int i = 0; i <= polyDegree; ++i) {
1066  p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
1067  pp[j + 1][i]);
1068  q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
1069  qq[j + 1][i]);
1070  }
1071  offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
1072  offsets[j + 1]);
1073  }
1074  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
1075  arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1076 
1077  Value pPoly = makePolynomialCalculation(builder, p, x);
1078  Value qPoly = makePolynomialCalculation(builder, q, x);
1079  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
1080  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
1081  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1082  formula, one);
1083 
1084  // erf is odd function: erf(x) = -erf(-x).
1085  Value negFormula = builder.create<arith::NegFOp>(formula);
1086  Value res =
1087  builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1088 
1089  rewriter.replaceOp(op, res);
1090 
1091  return success();
1092 }
1093 
1094 //----------------------------------------------------------------------------//
1095 // Exp approximation.
1096 //----------------------------------------------------------------------------//
1097 
1098 namespace {
1099 
1100 Value clampWithNormals(ImplicitLocOpBuilder &builder, const VectorShape shape,
1101  Value value, float lowerBound, float upperBound) {
1102  assert(!std::isnan(lowerBound));
1103  assert(!std::isnan(upperBound));
1104 
1105  auto bcast = [&](Value value) -> Value {
1106  return broadcast(builder, value, shape);
1107  };
1108 
1109  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1110  return builder.create<arith::SelectOp>(
1111  builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
1112  };
1113 
1114  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1115  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1116  // arith::{Max,Min}FOp.
1117  value = selectCmp(arith::CmpFPredicate::UGE, value,
1118  bcast(f32Cst(builder, lowerBound)));
1119  value = selectCmp(arith::CmpFPredicate::ULE, value,
1120  bcast(f32Cst(builder, upperBound)));
1121  return value;
1122 }
1123 
1124 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
1125 public:
1127 
1128  LogicalResult matchAndRewrite(math::ExpOp op,
1129  PatternRewriter &rewriter) const final;
1130 };
1131 
// Expands math.exp (f32 only) into the Cephes-style approximation described
// in the comments below: range-reduce to e^a * 2^n, evaluate e^a with a
// degree-5 FMA polynomial, and reconstruct 2^n via exp2I32.
LogicalResult
ExpApproximation::matchAndRewrite(math::ExpOp op,
                                  PatternRewriter &rewriter) const {
  auto shape = vectorShape(op.getOperand().getType());
  auto elementTy = getElementTypeOrSelf(op.getType());
  if (!elementTy.isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);

  // Small IR-building helpers; bcast splats scalar constants to the operand's
  // vector shape (no-op for scalars).
  auto add = [&](Value a, Value b) -> Value {
    return builder.create<arith::AddFOp>(a, b);
  };
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };

  // Polynomial approximation from Cephes.
  //
  // To compute e^x, we re-express it as
  //
  //   e^x = e^(a + b)
  //       = e^(a + n log(2))
  //       = e^a * 2^n.
  //
  // We choose n = round(x / log(2)), restricting the value of `a` to
  // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
  // relative error between our approximation and the true value of e^a is less
  // than 2^-22.5 for all values of `a` within this range.

  // Restrict input to a small range, including some values that evaluate to
  // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
  // log(F32_EPSILON). We do so because this routine always flushes denormal
  // floating points to 0. Therefore, we only need to worry about exponentiating
  // up to the smallest representable non-denormal floating point, which is
  // 2^-126.

  // Constants.
  Value cstHalf = bcast(f32Cst(builder, 0.5f));
  Value cstOne = bcast(f32Cst(builder, 1.0f));

  // 1/log(2)
  Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));

  // cstExpC1 + cstExpC2 together represent -log(2) in extended precision;
  // cstExpP0..P5 are the e^a polynomial coefficients.
  Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
  Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
  Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
  Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
  Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
  Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
  Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
  Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));

  // Our computations below aren't particularly sensitive to the exact choices
  // here, so we choose values a bit larger/smaller than
  //
  //   log(F32_MAX) = 88.723...
  //   log(2^-126) = -87.337...
  Value x = op.getOperand();
  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
  // n = round(x / log(2)), computed as floor(x/log(2) + 0.5).
  Value n = floor(fmla(x, cstLog2ef, cstHalf));

  // When we eventually do the multiplication in e^a * 2^n, we need to handle
  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
  // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
  // smallest fp32 exponent.
  //
  // A straightforward solution would be to detect n out of range and split it
  // up, doing
  //
  //   e^a * 2^n = e^a * 2^(n1 + n2)
  //             = (2^n1 * e^a) * 2^n2.
  //
  // But it turns out this approach is quite slow, probably because it
  // manipulates subnormal values.
  //
  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
  // this value of `a` is outside our previously specified range, e^a will still
  // only have a relative error of approximately 2^-16 at worse. In practice
  // this seems to work well enough; it passes our exhaustive tests, breaking
  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
  // we should return inf).
  //
  // In the case where n' = -127, the original input value of x is so small that
  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
  // normal floating point, and since we flush denormals, we simply return 0. We
  // do this in a branchless way by observing that our code for constructing 2^n
  // produces 0 if n = -127.
  //
  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
  //
  //   n' = -127 implies n <= -127
  //            implies round(x / log(2)) <= -127
  //            implies x/log(2) < -126.5
  //            implies x < -126.5 * log(2)
  //            implies e^x < e^(-126.5 * log(2))
  //            implies e^x < 2^-126.5 < 2^-126
  //
  //   This proves that n' = -127 implies e^x < 2^-126.
  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);

  // Computes x = x - n' * log(2), the value for `a`
  x = fmla(cstExpC1, n, x);
  x = fmla(cstExpC2, n, x);

  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
  Value z = fmla(x, cstExpP0, cstExpP1);
  z = fmla(z, x, cstExpP2);
  z = fmla(z, x, cstExpP3);
  z = fmla(z, x, cstExpP4);
  z = fmla(z, x, cstExpP5);
  z = fmla(z, mul(x, x), x);
  z = add(cstOne, z);

  // Convert n' to an i32. This is safe because we clamped it above.
  auto i32Vec = broadcast(builder.getI32Type(), shape);
  Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);

  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
  Value pow2 = exp2I32(builder, nI32);

  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n = -127.
  Value ret = mul(z, pow2);

  rewriter.replaceOp(op, ret);
  return mlir::success();
}
1268 
1269 } // namespace
1270 
1271 //----------------------------------------------------------------------------//
1272 // ExpM1 approximation.
1273 //----------------------------------------------------------------------------//
1274 
1275 namespace {
1276 
1277 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1278 public:
1280 
1281  LogicalResult matchAndRewrite(math::ExpM1Op op,
1282  PatternRewriter &rewriter) const final;
1283 };
1284 } // namespace
1285 
// Expands math.expm1 (f32 only) as (exp(x) - 1) * (x / log(exp(x))), which
// compensates the cancellation in exp(x) - 1 for x near 0, with explicit
// handling of the u == 1, u - 1 == -1 and u == +inf cases.
LogicalResult
ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
                                    PatternRewriter &rewriter) const {
  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  VectorShape shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Splat scalar constants to the operand's vector shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };

  // expm1(x) = exp(x) - 1 = u - 1.
  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
  Value cstOne = bcast(f32Cst(builder, 1.0f));
  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
  Value x = op.getOperand();
  Value u = builder.create<math::ExpOp>(x);
  // UEQ (unordered-or-equal) also catches NaN, hence the variable name.
  Value uEqOneOrNaN =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
      arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
  // logU = log(u) ~= x
  Value logU = builder.create<math::LogOp>(u);

  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
  Value isInf =
      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);

  // (u - 1) * (x / ~x)
  Value expm1 = builder.create<arith::MulFOp>(
      uMinusOne, builder.create<arith::DivFOp>(x, logU));
  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
  // Select chain: near-zero (or NaN) -> x; underflow to -1 -> -1; else the
  // compensated formula above.
  Value approximation = builder.create<arith::SelectOp>(
      uEqOneOrNaN, x,
      builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
  rewriter.replaceOp(op, approximation);
  return success();
}
1328 
1329 //----------------------------------------------------------------------------//
1330 // Sin and Cos approximation.
1331 //----------------------------------------------------------------------------//
1332 
1333 namespace {
1334 
1335 template <bool isSine, typename OpTy>
1336 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1337 public:
1339 
1340  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1341 };
1342 } // namespace
1343 
1344 #define TWO_OVER_PI \
1345  0.6366197723675813430755350534900574481378385829618257949906693762L
1346 #define PI_OVER_2 \
1347  1.5707963267948966192313216916397514420985846996875529104874722961L
1348 
1349 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in
1350 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
1351 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
// Expands math.sin/math.cos (f32 only): reduce the argument to y in
// [0, pi/2) with k = floor(x * 2/pi), then evaluate either the sin or cos
// minimax polynomial and fix the sign based on the quadrant k mod 4.
template <bool isSine, typename OpTy>
LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
    OpTy op, PatternRewriter &rewriter) const {
  static_assert(
      llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
      "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");

  if (!getElementTypeOrSelf(op.getOperand()).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  VectorShape shape = vectorShape(op.getOperand());

  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
  // Small IR-building helpers; bcast splats scalar constants to the operand's
  // vector shape (no-op for scalars).
  auto bcast = [&](Value value) -> Value {
    return broadcast(builder, value, shape);
  };
  auto mul = [&](Value a, Value b) -> Value {
    return builder.create<arith::MulFOp>(a, b);
  };
  auto sub = [&](Value a, Value b) -> Value {
    return builder.create<arith::SubFOp>(a, b);
  };
  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };

  auto i32Vec = broadcast(builder.getI32Type(), shape);
  auto fPToSingedInteger = [&](Value a) -> Value {
    return builder.create<arith::FPToSIOp>(i32Vec, a);
  };

  // a & 3 == a mod 4 (a is non-negative here since k = floor(x * 2/pi) is
  // converted after flooring).
  auto modulo4 = [&](Value a) -> Value {
    return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
  };

  auto isEqualTo = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
  };

  auto isGreaterThan = [&](Value a, Value b) -> Value {
    return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
  };

  auto select = [&](Value cond, Value t, Value f) -> Value {
    return builder.create<arith::SelectOp>(cond, t, f);
  };

  auto fmla = [&](Value a, Value b, Value c) {
    return builder.create<math::FmaOp>(a, b, c);
  };

  auto bitwiseOr = [&](Value a, Value b) {
    return builder.create<arith::OrIOp>(a, b);
  };

  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));

  Value x = op.getOperand();

  // Range reduction: k counts quarter-turns, y = x - k * pi/2 is the reduced
  // argument.
  Value k = floor(mul(x, twoOverPi));

  Value y = sub(x, mul(k, piOverTwo));

  Value cstOne = bcast(f32Cst(builder, 1.0));
  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));

  // Sine polynomial coefficients (powers y^2..y^10).
  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
  Value cstSC8 =
      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
  Value cstSC10 =
      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));

  // Cosine polynomial coefficients (powers y^2..y^10).
  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
  Value cstCC10 =
      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));

  Value kMod4 = modulo4(fPToSingedInteger(k));

  // Quadrant flags: kRi is true where k mod 4 == i.
  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));

  // Which polynomial to use (the cos one or the sin one) and whether the
  // result must be negated, per quadrant.
  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
                               : bitwiseOr(kR1, kR2);

  Value y2 = mul(y, y);

  // base is 1 for the cos polynomial (even series) and y for sin (odd
  // series); the coefficient set is selected to match.
  Value base = select(sinuseCos, cstOne, y);
  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);

  // Horner evaluation in y^2, then scale by base (1 or y).
  Value v1 = fmla(y2, cstC10, cstC8);
  Value v2 = fmla(y2, v1, cstC6);
  Value v3 = fmla(y2, v2, cstC4);
  Value v4 = fmla(y2, v3, cstC2);
  Value v5 = fmla(y2, v4, cstOne);
  Value v6 = mul(base, v5);

  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);

  rewriter.replaceOp(op, approximation);

  return success();
}
1465 
1466 //----------------------------------------------------------------------------//
1467 // Cbrt approximation.
1468 //----------------------------------------------------------------------------//
1469 
1470 namespace {
1471 struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1473 
1474  LogicalResult matchAndRewrite(math::CbrtOp op,
1475  PatternRewriter &rewriter) const final;
1476 };
1477 } // namespace
1478 
1479 // Estimation of cube-root using an algorithm defined in
1480 // Hacker's Delight 2nd Edition.
// Expands math.cbrt (f32 only): seed an estimate of |x|^(1/3) by integer
// manipulation of the float bits (exponent scaled by ~1/3 plus a magic
// constant), refine with two Newton steps, then restore the sign.
LogicalResult
CbrtApproximation::matchAndRewrite(math::CbrtOp op,
                                   PatternRewriter &rewriter) const {
  auto operand = op.getOperand();
  if (!getElementTypeOrSelf(operand).isF32())
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  VectorShape shape = vectorShape(operand);

  // Integer type with the same bit width as the float type, used for the
  // bitcast trick below.
  Type floatTy = getElementTypeOrSelf(operand.getType());
  Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());

  // Convert to vector types if necessary.
  floatTy = broadcast(floatTy, shape);
  intTy = broadcast(intTy, shape);

  // Materialize a constant and splat it to the operand's shape.
  auto bconst = [&](TypedAttr attr) -> Value {
    Value value = b.create<arith::ConstantOp>(attr);
    return broadcast(b, value, shape);
  };

  // Declare the initial values:
  Value intTwo = bconst(b.getI32IntegerAttr(2));
  Value intFour = bconst(b.getI32IntegerAttr(4));
  Value intEight = bconst(b.getI32IntegerAttr(8));
  Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
  Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
  Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
  Value fpZero = bconst(b.getF32FloatAttr(0.0f));

  // Compute an approximation of one third:
  // union {int ix; float x;};
  // x = x0;
  // ix = ix/4 + ix/16;
  // Arithmetic shifts by 2/4/8 bits implement the divisions by 4/16/256.
  Value absValue = b.create<math::AbsFOp>(operand);
  Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
  Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
  Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
  intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);

  // ix = ix + ix/16;
  divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
  intValue = b.create<arith::AddIOp>(intValue, divideBy16);

  // ix = ix + ix/256;
  Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
  intValue = b.create<arith::AddIOp>(intValue, divideBy256);

  // ix = 0x2a5137a0 + ix;
  intValue = b.create<arith::AddIOp>(intValue, intMagic);

  // Perform one newtons step:
  // x = 0.33333333f*(2.0f*x + x0/(x*x));
  Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
  Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
  Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
  Value divSquared = b.create<arith::DivFOp>(absValue, squared);
  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);

  // Second Newton step, same formula:
  // x = 0.33333333f*(2.0f*x + x0/(x*x));
  squared = b.create<arith::MulFOp>(floatValue, floatValue);
  mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
  divSquared = b.create<arith::DivFOp>(absValue, squared);
  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);

  // Check for zero and restore sign.
  // The explicit zero select guards the x0/(x*x) division result for a zero
  // input; copysign makes cbrt odd (cbrt(-x) = -cbrt(x)).
  Value isZero =
      b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
  floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
  floatValue = b.create<math::CopySignOp>(floatValue, operand);

  rewriter.replaceOp(op, floatValue);
  return success();
}
1558 
1559 //----------------------------------------------------------------------------//
1560 // Rsqrt approximation.
1561 //----------------------------------------------------------------------------//
1562 
1563 namespace {
1564 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1566 
1567  LogicalResult matchAndRewrite(math::RsqrtOp op,
1568  PatternRewriter &rewriter) const final;
1569 };
1570 } // namespace
1571 
1572 LogicalResult
1573 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1574  PatternRewriter &rewriter) const {
1575  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1576  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1577 
1578  VectorShape shape = vectorShape(op.getOperand());
1579 
1580  // Only support already-vectorized rsqrt's.
1581  if (shape.empty() || shape.sizes.back() % 8 != 0)
1582  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1583 
1584  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1585  auto bcast = [&](Value value) -> Value {
1586  return broadcast(builder, value, shape);
1587  };
1588 
1589  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1590  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1591  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1592  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1593 
1594  Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1595 
1596  // Select only the inverse sqrt of positive normals (denormals are
1597  // flushed to zero).
1598  Value ltMinMask = builder.create<arith::CmpFOp>(
1599  arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1600  Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1601  op.getOperand(), cstPosInf);
1602  Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1603 
1604  // Compute an approximate result.
1606  builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1607  return builder.create<x86vector::RsqrtOp>(operands);
1608  });
1609 
1610  // Do a single step of Newton-Raphson iteration to improve the approximation.
1611  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1612  // It is essential to evaluate the inner term like this because forming
1613  // y_n^2 may over- or underflow.
1614  Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1615  Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1616  Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1617 
1618  // Select the result of the Newton-Raphson step for positive normal arguments.
1619  // For other arguments, choose the output of the intrinsic. This will
1620  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1621  // x is zero or a positive denormalized float (equivalent to flushing positive
1622  // denormalized inputs to zero).
1623  Value res =
1624  builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1625  rewriter.replaceOp(op, res);
1626 
1627  return success();
1628 }
1629 
1630 //----------------------------------------------------------------------------//
1631 
1633  RewritePatternSet &patterns) {
1634  patterns.add<TanhApproximation>(patterns.getContext());
1635 }
1636 
1638  RewritePatternSet &patterns) {
1639  patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1640 }
1641 
1643  RewritePatternSet &patterns,
1645  // Patterns for leveraging existing f32 lowerings on other data types.
1646  patterns
1647  .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1648  ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1649  ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1650  ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1651  ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1652  ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1653  patterns.getContext());
1654 
1655  patterns
1656  .add<AtanApproximation, Atan2Approximation, TanhApproximation,
1657  LogApproximation, Log2Approximation, Log1pApproximation,
1658  ErfPolynomialApproximation, AsinPolynomialApproximation,
1659  AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1660  CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1661  SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
1662  if (options.enableAvx2) {
1663  patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1664  patterns.getContext());
1665  }
1666 }
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
#define LN2_VALUE
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
#define PI_OVER_2
#define TWO_OVER_PI
static Type broadcast(Type type, VectorShape shape)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static VectorShape vectorShape(Type type)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
#define LOG2E_VALUE
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
FloatType getF32Type()
Definition: Builders.cpp:67
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
IntegerType getI32Type()
Definition: Builders.cpp:87
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:91
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:257
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:52
bool isF16() const
Definition: Types.cpp:50
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:76
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
Include the generated interface declarations.
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
Definition: IndexingUtils.h:69
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.