MLIR  19.0.0git
PolynomialApproximation.cpp
Go to the documentation of this file.
1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <climits>
15 #include <cmath>
16 #include <cstddef>
17 
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/OpDefinition.h"
30 #include "mlir/IR/PatternMatch.h"
31 #include "mlir/IR/TypeUtilities.h"
34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
37 
38 using namespace mlir;
39 using namespace mlir::math;
40 using namespace mlir::vector;
41 
42 // Returns vector shape if the type is a vector. Returns an empty shape if it is
43 // not a vector.
45  auto vectorType = dyn_cast<VectorType>(type);
46  return vectorType ? vectorType.getShape() : ArrayRef<int64_t>();
47 }
48 
50  return vectorShape(value.getType());
51 }
52 
53 //----------------------------------------------------------------------------//
54 // Broadcast scalar types and values into vector types and values.
55 //----------------------------------------------------------------------------//
56 
57 // Broadcasts scalar type into vector type (iff shape is non-scalar).
58 static Type broadcast(Type type, ArrayRef<int64_t> shape) {
59  assert(!isa<VectorType>(type) && "must be scalar type");
60  return !shape.empty() ? VectorType::get(shape, type) : type;
61 }
62 
63 // Broadcasts scalar value into vector (iff shape is non-scalar).
64 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
65  ArrayRef<int64_t> shape) {
66  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
67  auto type = broadcast(value.getType(), shape);
68  return !shape.empty() ? builder.create<BroadcastOp>(type, value) : value;
69 }
70 
71 //----------------------------------------------------------------------------//
72 // Helper function to handle n-D vectors with 1-D operations.
73 //----------------------------------------------------------------------------//
74 
75 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
76 // and calls the compute function with 1-D vector operands. Stitches back all
77 // results into the original n-D vector result.
78 //
79 // Examples: vectorWidth = 8
80 // - vector<4x8xf32> unrolled 4 times
81 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
82 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
83 //
84 // Some math approximations rely on ISA-specific operations that only accept
85 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
86 //
87 // It is the caller's responsibility to verify that the inner dimension is
88 // divisible by the vectorWidth, and that all operands have the same vector
89 // shape.
90 static Value
92  ValueRange operands, int64_t vectorWidth,
94  assert(!operands.empty() && "operands must be not empty");
95  assert(vectorWidth > 0 && "vector width must be larger than 0");
96 
97  VectorType inputType = cast<VectorType>(operands[0].getType());
98  ArrayRef<int64_t> inputShape = inputType.getShape();
99 
100  // If input shape matches target vector width, we can just call the
101  // user-provided compute function with the operands.
102  if (inputShape == llvm::ArrayRef(vectorWidth))
103  return compute(operands);
104 
105  // Check if the inner dimension has to be expanded, or we can directly iterate
106  // over the outer dimensions of the vector.
107  int64_t innerDim = inputShape.back();
108  int64_t expansionDim = innerDim / vectorWidth;
109  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
110 
111  // Maybe expand operands to the higher rank vector shape that we'll use to
112  // iterate over and extract one dimensional vectors.
113  SmallVector<int64_t> expandedShape(inputShape.begin(), inputShape.end());
114  SmallVector<Value> expandedOperands(operands);
115 
116  if (expansionDim > 1) {
117  // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
118  expandedShape.insert(expandedShape.end() - 1, expansionDim);
119  expandedShape.back() = vectorWidth;
120 
121  for (unsigned i = 0; i < operands.size(); ++i) {
122  auto operand = operands[i];
123  auto eltType = cast<VectorType>(operand.getType()).getElementType();
124  auto expandedType = VectorType::get(expandedShape, eltType);
125  expandedOperands[i] =
126  builder.create<vector::ShapeCastOp>(expandedType, operand);
127  }
128  }
129 
130  // Iterate over all outer dimensions of the compute shape vector type.
131  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
132  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
133  auto strides = computeStrides(iterationDims);
134 
135  // Compute results for each one dimensional vector.
136  SmallVector<Value> results(maxIndex);
137 
138  for (int64_t i = 0; i < maxIndex; ++i) {
139  auto offsets = delinearize(i, strides);
140 
141  SmallVector<Value> extracted(expandedOperands.size());
142  for (const auto &tuple : llvm::enumerate(expandedOperands))
143  extracted[tuple.index()] =
144  builder.create<vector::ExtractOp>(tuple.value(), offsets);
145 
146  results[i] = compute(extracted);
147  }
148 
149  // Stitch results together into one large vector.
150  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
151  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
152  Value result = builder.create<arith::ConstantOp>(
153  resultExpandedType, builder.getZeroAttr(resultExpandedType));
154 
155  for (int64_t i = 0; i < maxIndex; ++i)
156  result = builder.create<vector::InsertOp>(results[i], result,
157  delinearize(i, strides));
158 
159  // Reshape back to the original vector shape.
160  return builder.create<vector::ShapeCastOp>(
161  VectorType::get(inputShape, resultEltType), result);
162 }
163 
164 //----------------------------------------------------------------------------//
165 // Helper functions to create constants.
166 //----------------------------------------------------------------------------//
167 
168 static Value floatCst(ImplicitLocOpBuilder &builder, float value,
169  Type elementType) {
170  assert((elementType.isF16() || elementType.isF32()) &&
171  "x must be f16 or f32 type.");
172  return builder.create<arith::ConstantOp>(
173  builder.getFloatAttr(elementType, value));
174 }
175 
176 static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
177  return builder.create<arith::ConstantOp>(builder.getF32FloatAttr(value));
178 }
179 
180 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
181  return builder.create<arith::ConstantOp>(builder.getI32IntegerAttr(value));
182 }
183 
184 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
185  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
186  return builder.create<arith::BitcastOp>(builder.getF32Type(), i32Value);
187 }
188 
189 //----------------------------------------------------------------------------//
190 // Helper functions to build math functions approximations.
191 //----------------------------------------------------------------------------//
192 
193 // Return the minimum of the two values or NaN if value is NaN
194 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
195  return builder.create<arith::SelectOp>(
196  builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
197  value, bound);
198 }
199 
200 // Return the maximum of the two values or NaN if value is NaN
201 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
202  return builder.create<arith::SelectOp>(
203  builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
204  value, bound);
205 }
206 
207 // Return the clamped value or NaN if value is NaN
208 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
209  Value upperBound) {
210  return max(builder, min(builder, value, upperBound), lowerBound);
211 }
212 
213 // Decomposes given floating point value `arg` into a normalized fraction and
214 // an integral power of two (see std::frexp). Returned values have float type.
215 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
216  bool isPositive = false) {
217  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
218  ArrayRef<int64_t> shape = vectorShape(arg);
219 
220  auto bcast = [&](Value value) -> Value {
221  return broadcast(builder, value, shape);
222  };
223 
224  auto i32 = builder.getIntegerType(32);
225  auto i32Vec = broadcast(i32, shape);
226  auto f32Vec = broadcast(builder.getF32Type(), shape);
227 
228  Value cst126f = f32Cst(builder, 126.0f);
229  Value cstHalf = f32Cst(builder, 0.5f);
230  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
231 
232  // Bitcast to i32 for bitwise operations.
233  Value i32Half = builder.create<arith::BitcastOp>(i32, cstHalf);
234  Value i32InvMantMask = builder.create<arith::BitcastOp>(i32, cstInvMantMask);
235  Value i32Arg = builder.create<arith::BitcastOp>(i32Vec, arg);
236 
237  // Compute normalized fraction.
238  Value tmp0 = builder.create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
239  Value tmp1 = builder.create<arith::OrIOp>(tmp0, bcast(i32Half));
240  Value normalizedFraction = builder.create<arith::BitcastOp>(f32Vec, tmp1);
241 
242  // Compute exponent.
243  Value arg0 = isPositive ? arg : builder.create<math::AbsFOp>(arg);
244  Value biasedExponentBits = builder.create<arith::ShRUIOp>(
245  builder.create<arith::BitcastOp>(i32Vec, arg0),
246  bcast(i32Cst(builder, 23)));
247  Value biasedExponent =
248  builder.create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
249  Value exponent =
250  builder.create<arith::SubFOp>(biasedExponent, bcast(cst126f));
251 
252  return {normalizedFraction, exponent};
253 }
254 
255 // Computes exp2 for an i32 argument.
256 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
257  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
258  ArrayRef<int64_t> shape = vectorShape(arg);
259 
260  auto bcast = [&](Value value) -> Value {
261  return broadcast(builder, value, shape);
262  };
263 
264  auto f32Vec = broadcast(builder.getF32Type(), shape);
265  // The exponent of f32 located at 23-bit.
266  auto exponetBitLocation = bcast(i32Cst(builder, 23));
267  // Set the exponent bias to zero.
268  auto bias = bcast(i32Cst(builder, 127));
269 
270  Value biasedArg = builder.create<arith::AddIOp>(arg, bias);
271  Value exp2ValueInt =
272  builder.create<arith::ShLIOp>(biasedArg, exponetBitLocation);
273  Value exp2ValueF32 = builder.create<arith::BitcastOp>(f32Vec, exp2ValueInt);
274 
275  return exp2ValueF32;
276 }
277 
278 namespace {
279 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
280  llvm::ArrayRef<Value> coeffs, Value x) {
281  Type elementType = getElementTypeOrSelf(x);
282  assert((elementType.isF32() || elementType.isF16()) &&
283  "x must be f32 or f16 type");
284  ArrayRef<int64_t> shape = vectorShape(x);
285 
286  if (coeffs.empty())
287  return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
288 
289  if (coeffs.size() == 1)
290  return coeffs[0];
291 
292  Value res = builder.create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
293  coeffs[coeffs.size() - 2]);
294  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
295  res = builder.create<math::FmaOp>(x, res, coeffs[i]);
296  }
297  return res;
298 }
299 } // namespace
300 
301 //----------------------------------------------------------------------------//
302 // Helper function/pattern to insert casts for reusing F32 bit expansion.
303 //----------------------------------------------------------------------------//
304 
305 template <typename T>
307  // Conservatively only allow where the operand and result types are exactly 1.
308  Type origType = op->getResultTypes().front();
309  for (Type t : llvm::drop_begin(op->getResultTypes()))
310  if (origType != t)
311  return rewriter.notifyMatchFailure(op, "required all types to match");
312  for (Type t : op->getOperandTypes())
313  if (origType != t)
314  return rewriter.notifyMatchFailure(op, "required all types to match");
315 
316  // Skip if already F32 or larger than 32 bits.
317  if (getElementTypeOrSelf(origType).isF32() ||
318  getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
319  return failure();
320 
321  // Create F32 equivalent type.
322  Type newType;
323  if (auto shaped = dyn_cast<ShapedType>(origType)) {
324  newType = shaped.clone(rewriter.getF32Type());
325  } else if (isa<FloatType>(origType)) {
326  newType = rewriter.getF32Type();
327  } else {
328  return rewriter.notifyMatchFailure(op,
329  "unable to find F32 equivalent type");
330  }
331 
332  Location loc = op->getLoc();
333  SmallVector<Value> operands;
334  for (auto operand : op->getOperands())
335  operands.push_back(rewriter.create<arith::ExtFOp>(loc, newType, operand));
336  auto result =
337  rewriter.create<T>(loc, TypeRange{newType}, operands, op->getAttrs());
338  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
339  return success();
340 }
341 
342 namespace {
343 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
344 // op.
345 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
346 // all in lower precision. Currently this is only fallback support and performs
347 // simplistic casting.
348 template <typename T>
349 struct ReuseF32Expansion : public OpRewritePattern<T> {
350 public:
352  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
353  static_assert(
354  T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
355  "requires same operands and result types");
356  return insertCasts<T>(op, rewriter);
357  }
358 };
359 } // namespace
360 
361 //----------------------------------------------------------------------------//
362 // AtanOp approximation.
363 //----------------------------------------------------------------------------//
364 
365 namespace {
366 struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
367 public:
369 
370  LogicalResult matchAndRewrite(math::AtanOp op,
371  PatternRewriter &rewriter) const final;
372 };
373 } // namespace
374 
376 AtanApproximation::matchAndRewrite(math::AtanOp op,
377  PatternRewriter &rewriter) const {
378  auto operand = op.getOperand();
379  if (!getElementTypeOrSelf(operand).isF32())
380  return rewriter.notifyMatchFailure(op, "unsupported operand type");
381 
383 
384  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
385  Value abs = builder.create<math::AbsFOp>(operand);
386 
387  auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
388 
389  // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
390  auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
391  Value cmp2 =
392  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, twoThirds);
393  Value addone = builder.create<arith::AddFOp>(abs, one);
394  Value subone = builder.create<arith::SubFOp>(abs, one);
395  Value xnum = builder.create<arith::SelectOp>(cmp2, subone, abs);
396  Value xden = builder.create<arith::SelectOp>(cmp2, addone, one);
397 
398  auto bcast = [&](Value value) -> Value {
399  return broadcast(builder, value, shape);
400  };
401 
402  // Break into the <= 0.66 or > 2.41 we do x or 1/x:
403  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
404  Value cmp1 =
405  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, abs, tan3pio8);
406  xnum = builder.create<arith::SelectOp>(cmp1, one, xnum);
407  xden = builder.create<arith::SelectOp>(cmp1, abs, xden);
408 
409  Value x = builder.create<arith::DivFOp>(xnum, xden);
410  Value xx = builder.create<arith::MulFOp>(x, x);
411 
412  // Perform the Taylor series approximation for atan over the range
413  // [0.0, 0.66].
414  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
415  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
416  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
417  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
418  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
419  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
420  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
421  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
422  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
423  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
424 
425  // Apply the polynomial approximation for the numerator:
426  Value n = p0;
427  n = builder.create<math::FmaOp>(xx, n, p1);
428  n = builder.create<math::FmaOp>(xx, n, p2);
429  n = builder.create<math::FmaOp>(xx, n, p3);
430  n = builder.create<math::FmaOp>(xx, n, p4);
431  n = builder.create<arith::MulFOp>(n, xx);
432 
433  // Apply the polynomial approximation for the denominator:
434  Value d = q0;
435  d = builder.create<math::FmaOp>(xx, d, q1);
436  d = builder.create<math::FmaOp>(xx, d, q2);
437  d = builder.create<math::FmaOp>(xx, d, q3);
438  d = builder.create<math::FmaOp>(xx, d, q4);
439 
440  // Compute approximation of theta:
441  Value ans0 = builder.create<arith::DivFOp>(n, d);
442  ans0 = builder.create<math::FmaOp>(ans0, x, x);
443 
444  // Correct for the input mapping's angles:
445  Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
446  Value ans2 = builder.create<arith::AddFOp>(mpi4, ans0);
447  Value ans = builder.create<arith::SelectOp>(cmp2, ans2, ans0);
448 
449  Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
450  Value ans1 = builder.create<arith::SubFOp>(mpi2, ans0);
451  ans = builder.create<arith::SelectOp>(cmp1, ans1, ans);
452 
453  // Correct for signing of the input.
454  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
455  return success();
456 }
457 
458 //----------------------------------------------------------------------------//
// Atan2Op approximation.
460 //----------------------------------------------------------------------------//
461 
462 namespace {
463 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
464 public:
466 
467  LogicalResult matchAndRewrite(math::Atan2Op op,
468  PatternRewriter &rewriter) const final;
469 };
470 } // namespace
471 
473 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
474  PatternRewriter &rewriter) const {
475  auto y = op.getOperand(0);
476  auto x = op.getOperand(1);
477  if (!getElementTypeOrSelf(x).isF32())
478  return rewriter.notifyMatchFailure(op, "unsupported operand type");
479 
480  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
482 
483  // Compute atan in the valid range.
484  auto div = builder.create<arith::DivFOp>(y, x);
485  auto atan = builder.create<math::AtanOp>(div);
486 
487  // Determine what the atan would be for a 180 degree rotation.
488  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
489  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
490  auto addPi = builder.create<arith::AddFOp>(atan, pi);
491  auto subPi = builder.create<arith::SubFOp>(atan, pi);
492  auto atanGt =
493  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
494  auto flippedAtan = builder.create<arith::SelectOp>(atanGt, subPi, addPi);
495 
496  // Determine whether to directly use atan or use the 180 degree flip
497  auto xGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
498  Value result = builder.create<arith::SelectOp>(xGt, atan, flippedAtan);
499 
500  // Handle x = 0, y > 0
501  Value xZero =
502  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
503  Value yGt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
504  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yGt);
505  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
506  result = builder.create<arith::SelectOp>(isHalfPi, halfPi, result);
507 
508  // Handle x = 0, y < 0
509  Value yLt = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
510  Value isNegativeHalfPiPi = builder.create<arith::AndIOp>(xZero, yLt);
511  auto negativeHalfPiPi =
512  broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
513  result = builder.create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
514  result);
515 
516  // Handle x = 0, y = 0;
517  Value yZero =
518  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
519  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
520  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
521  result = builder.create<arith::SelectOp>(isNan, cstNan, result);
522 
523  rewriter.replaceOp(op, result);
524  return success();
525 }
526 
527 //----------------------------------------------------------------------------//
528 // TanhOp approximation.
529 //----------------------------------------------------------------------------//
530 
531 namespace {
532 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
533 public:
535 
536  LogicalResult matchAndRewrite(math::TanhOp op,
537  PatternRewriter &rewriter) const final;
538 };
539 } // namespace
540 
542 TanhApproximation::matchAndRewrite(math::TanhOp op,
543  PatternRewriter &rewriter) const {
545  return rewriter.notifyMatchFailure(op, "unsupported operand type");
546 
548 
549  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
550  auto bcast = [&](Value value) -> Value {
551  return broadcast(builder, value, shape);
552  };
553 
554  // Clamp operand into [plusClamp, minusClamp] range.
555  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
556  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
557  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
558 
559  // Mask for tiny values that are approximated with `operand`.
560  Value tiny = bcast(f32Cst(builder, 0.0004f));
561  Value tinyMask = builder.create<arith::CmpFOp>(
562  arith::CmpFPredicate::OLT, builder.create<math::AbsFOp>(op.getOperand()),
563  tiny);
564 
565  // The monomial coefficients of the numerator polynomial (odd).
566  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
567  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
568  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
569  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
570  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
571  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
572  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
573 
574  // The monomial coefficients of the denominator polynomial (even).
575  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
576  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
577  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
578  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
579 
580  // Since the polynomials are odd/even, we need x^2.
581  Value x2 = builder.create<arith::MulFOp>(x, x);
582 
583  // Evaluate the numerator polynomial p.
584  Value p = builder.create<math::FmaOp>(x2, alpha13, alpha11);
585  p = builder.create<math::FmaOp>(x2, p, alpha9);
586  p = builder.create<math::FmaOp>(x2, p, alpha7);
587  p = builder.create<math::FmaOp>(x2, p, alpha5);
588  p = builder.create<math::FmaOp>(x2, p, alpha3);
589  p = builder.create<math::FmaOp>(x2, p, alpha1);
590  p = builder.create<arith::MulFOp>(x, p);
591 
592  // Evaluate the denominator polynomial q.
593  Value q = builder.create<math::FmaOp>(x2, beta6, beta4);
594  q = builder.create<math::FmaOp>(x2, q, beta2);
595  q = builder.create<math::FmaOp>(x2, q, beta0);
596 
597  // Divide the numerator by the denominator.
598  Value res = builder.create<arith::SelectOp>(
599  tinyMask, x, builder.create<arith::DivFOp>(p, q));
600 
601  rewriter.replaceOp(op, res);
602 
603  return success();
604 }
605 
// ln(2), written as a long double literal with far more digits than f32
// needs. The log approximation in this file narrows it with
// static_cast<float>. NOTE(review): kept as a macro because code later in the
// file (not visible here) may rely on it being a preprocessor name.
#define LN2_VALUE \
  0.693147180559945309417232121458176568075500134360255254120680009493393621L
// log2(e) = 1/ln(2), same representation as LN2_VALUE; used for the base-2
// variant of the log approximation.
#define LOG2E_VALUE \
  1.442695040888963407359924681001892137426645954152985934135449406931109219L
610 
611 //----------------------------------------------------------------------------//
612 // LogOp and Log2Op approximation.
613 //----------------------------------------------------------------------------//
614 
615 namespace {
616 template <typename Op>
617 struct LogApproximationBase : public OpRewritePattern<Op> {
619 
620  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
621  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
622  bool base2) const;
623 };
624 } // namespace
625 
626 // This approximation comes from Julien Pommier's SSE math library.
627 // Link: http://gruntthepeon.free.fr/ssemath
628 template <typename Op>
630 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
631  bool base2) const {
633  return rewriter.notifyMatchFailure(op, "unsupported operand type");
634 
636 
637  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
638  auto bcast = [&](Value value) -> Value {
639  return broadcast(builder, value, shape);
640  };
641 
642  Value cstZero = bcast(f32Cst(builder, 0.0f));
643  Value cstOne = bcast(f32Cst(builder, 1.0f));
644  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
645 
646  // The smallest non denormalized float number.
647  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
648  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
649  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
650  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
651 
652  // Polynomial coefficients.
653  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
654  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
655  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
656  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
657  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
658  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
659  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
660  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
661  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
662  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
663 
664  Value x = op.getOperand();
665 
666  // Truncate input values to the minimum positive normal.
667  x = max(builder, x, cstMinNormPos);
668 
669  // Extract significant in the range [0.5,1) and exponent.
670  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
671  x = pair.first;
672  Value e = pair.second;
673 
674  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
675  // by -1.0. The values are then centered around 0, which improves the
676  // stability of the polynomial evaluation:
677  //
678  // if( x < SQRTHF ) {
679  // e -= 1;
680  // x = x + x - 1.0;
681  // } else { x = x - 1.0; }
682  Value mask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
683  cstCephesSQRTHF);
684  Value tmp = builder.create<arith::SelectOp>(mask, x, cstZero);
685 
686  x = builder.create<arith::SubFOp>(x, cstOne);
687  e = builder.create<arith::SubFOp>(
688  e, builder.create<arith::SelectOp>(mask, cstOne, cstZero));
689  x = builder.create<arith::AddFOp>(x, tmp);
690 
691  Value x2 = builder.create<arith::MulFOp>(x, x);
692  Value x3 = builder.create<arith::MulFOp>(x2, x);
693 
694  // Evaluate the polynomial approximant of degree 8 in three parts.
695  Value y0, y1, y2;
696  y0 = builder.create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
697  y1 = builder.create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
698  y2 = builder.create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
699  y0 = builder.create<math::FmaOp>(y0, x, cstCephesLogP2);
700  y1 = builder.create<math::FmaOp>(y1, x, cstCephesLogP5);
701  y2 = builder.create<math::FmaOp>(y2, x, cstCephesLogP8);
702  y0 = builder.create<math::FmaOp>(y0, x3, y1);
703  y0 = builder.create<math::FmaOp>(y0, x3, y2);
704  y0 = builder.create<arith::MulFOp>(y0, x3);
705 
706  y0 = builder.create<math::FmaOp>(cstNegHalf, x2, y0);
707  x = builder.create<arith::AddFOp>(x, y0);
708 
709  if (base2) {
710  Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
711  x = builder.create<math::FmaOp>(x, cstLog2e, e);
712  } else {
713  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
714  x = builder.create<math::FmaOp>(e, cstLn2, x);
715  }
716 
717  Value invalidMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
718  op.getOperand(), cstZero);
719  Value zeroMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
720  op.getOperand(), cstZero);
721  Value posInfMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
722  op.getOperand(), cstPosInf);
723 
724  // Filter out invalid values:
725  // • x == 0 -> -INF
726  // • x < 0 -> NAN
727  // • x == +INF -> +INF
728  Value aproximation = builder.create<arith::SelectOp>(
729  zeroMask, cstMinusInf,
730  builder.create<arith::SelectOp>(
731  invalidMask, cstNan,
732  builder.create<arith::SelectOp>(posInfMask, cstPosInf, x)));
733 
734  rewriter.replaceOp(op, aproximation);
735 
736  return success();
737 }
738 
739 namespace {
740 struct LogApproximation : public LogApproximationBase<math::LogOp> {
741  using LogApproximationBase::LogApproximationBase;
742 
743  LogicalResult matchAndRewrite(math::LogOp op,
744  PatternRewriter &rewriter) const final {
745  return logMatchAndRewrite(op, rewriter, /*base2=*/false);
746  }
747 };
748 } // namespace
749 
750 namespace {
751 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
752  using LogApproximationBase::LogApproximationBase;
753 
754  LogicalResult matchAndRewrite(math::Log2Op op,
755  PatternRewriter &rewriter) const final {
756  return logMatchAndRewrite(op, rewriter, /*base2=*/true);
757  }
758 };
759 } // namespace
760 
761 //----------------------------------------------------------------------------//
762 // Log1p approximation.
763 //----------------------------------------------------------------------------//
764 
765 namespace {
766 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
767 public:
769 
770  LogicalResult matchAndRewrite(math::Log1pOp op,
771  PatternRewriter &rewriter) const final;
772 };
773 } // namespace
774 
775 // Approximate log(1+x).
// Replaces math.log1p with a selection between `x` and x * log(u) / (u - 1)
// where u = 1 + x. Constants are built with f32Cst, so the expansion targets
// f32 elements (scalar or vector, via `bcast`). NOTE(review): the operand-type
// guard feeding the "unsupported operand type" failure and the definition of
// `shape` sit on source lines not shown in this listing — confirm upstream.
777 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
778  PatternRewriter &rewriter) const {
780  return rewriter.notifyMatchFailure(op, "unsupported operand type");
781 
783 
784  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
785  auto bcast = [&](Value value) -> Value {
786  return broadcast(builder, value, shape);
787  };
788 
789  // Approximate log(1+x) using the following, due to W. Kahan:
790  // u = x + 1.0;
791  // if (u == 1.0 || u == inf) return x;
792  // return x * log(u) / (u - 1.0);
793  // ^^^^^^^^^^^^^^^^^^^^^^
794  // "logLarge" below.
795  Value cstOne = bcast(f32Cst(builder, 1.0f));
796  Value x = op.getOperand();
797  Value u = builder.create<arith::AddFOp>(x, cstOne);
// u == 1 means 1 + x rounded to exactly 1 (|x| is tiny). log1p(x) ~= x there,
// and the general formula would compute 0 / 0.
798  Value uSmall =
799  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
800  Value logU = builder.create<math::LogOp>(u);
// log(u) == u holds only for u == +inf (log(u) < u for every finite u > 0), so
// this detects overflow of 1 + x without materializing an infinity constant.
801  Value uInf =
802  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
// x * (log(u) / (u - 1)): the division cancels the rounding error made when
// forming u = 1 + x.
803  Value logLarge = builder.create<arith::MulFOp>(
804  x, builder.create<arith::DivFOp>(
805  logU, builder.create<arith::SubFOp>(u, cstOne)));
806  Value approximation = builder.create<arith::SelectOp>(
807  builder.create<arith::OrIOp>(uSmall, uInf), x, logLarge);
808  rewriter.replaceOp(op, approximation);
809  return success();
810 }
811 
812 //----------------------------------------------------------------------------//
813 // Erf approximation.
814 //----------------------------------------------------------------------------//
815 
816 // Approximates erf(x) with
817 // a - P(x)/Q(x)
818 // where P and Q are polynomials of degree 4.
819 // Different coefficients are chosen based on the value of x.
820 // The approximation error is ~2.5e-07.
821 // Boost's minimax tool that utilizes the Remez method was used to find the
822 // coefficients.
// Supports f32 and f16 element types, scalar or vector. The expansion works on
// |x| and restores the sign at the end (erf is odd). For |x| >= 3.75 the result
// is exactly 1. NOTE(review): the function's signature lines are not shown in
// this listing.
825  PatternRewriter &rewriter) const {
826  Value operand = op.getOperand();
827  Type elementType = getElementTypeOrSelf(operand);
828 
829  if (!(elementType.isF32() || elementType.isF16()))
830  return rewriter.notifyMatchFailure(op,
831  "only f32 and f16 type is supported.");
832  ArrayRef<int64_t> shape = vectorShape(operand);
833 
834  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
835  auto bcast = [&](Value value) -> Value {
836  return broadcast(builder, value, shape);
837  };
838 
839  const int intervalsCount = 3;
840  const int polyDegree = 4;
841 
842  Value zero = bcast(floatCst(builder, 0, elementType));
843  Value one = bcast(floatCst(builder, 1, elementType));
// Numerator coefficients pp[interval][power], lowest power first.
844  Value pp[intervalsCount][polyDegree + 1];
845  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
846  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
847  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
848  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
849  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
850  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
851  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
852  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
853  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
854  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
855  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
856  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
857  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
858  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
859  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
860 
// Denominator coefficients qq[interval][power], lowest power first.
861  Value qq[intervalsCount][polyDegree + 1];
862  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
863  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
864  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
865  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
866  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
867  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
868  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
869  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
870  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
871  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
872  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
873  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
874  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
875  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
876  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
877 
// Per-interval constant `a` in a + P(x)/Q(x).
878  Value offsets[intervalsCount];
879  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
880  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
881  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
882 
// Upper (exclusive) bound of each interval on |x|: [0, 0.8), [0.8, 2),
// [2, 3.75). Beyond the last bound the result saturates to 1.
883  Value bounds[intervalsCount];
884  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
885  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
886  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
887 
// x = |operand|. The sign is reapplied at the end.
888  Value isNegativeArg =
889  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
890  Value negArg = builder.create<arith::NegFOp>(operand);
891  Value x = builder.create<arith::SelectOp>(isNegativeArg, negArg, operand);
892 
893  Value offset = offsets[0];
894  Value p[polyDegree + 1];
895  Value q[polyDegree + 1];
896  for (int i = 0; i <= polyDegree; ++i) {
897  p[i] = pp[0][i];
898  q[i] = qq[0][i];
899  }
900 
// Start from interval 0 and switch to interval j+1 whenever x is not below
// bounds[j]. Because bounds increase, the last switch that fires wins.
901  // TODO: maybe use vector stacking to reduce the number of selects.
902  Value isLessThanBound[intervalsCount];
903  for (int j = 0; j < intervalsCount - 1; ++j) {
904  isLessThanBound[j] =
905  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[j]);
906  for (int i = 0; i <= polyDegree; ++i) {
907  p[i] = builder.create<arith::SelectOp>(isLessThanBound[j], p[i],
908  pp[j + 1][i]);
909  q[i] = builder.create<arith::SelectOp>(isLessThanBound[j], q[i],
910  qq[j + 1][i]);
911  }
912  offset = builder.create<arith::SelectOp>(isLessThanBound[j], offset,
913  offsets[j + 1]);
914  }
// ULT (unordered or less-than) is true for NaN, so a NaN input keeps the
// rational formula below and propagates NaN instead of being mapped to 1.
915  isLessThanBound[intervalsCount - 1] = builder.create<arith::CmpFOp>(
916  arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
917 
918  Value pPoly = makePolynomialCalculation(builder, p, x);
919  Value qPoly = makePolynomialCalculation(builder, q, x);
920  Value rationalPoly = builder.create<arith::DivFOp>(pPoly, qPoly);
921  Value formula = builder.create<arith::AddFOp>(offset, rationalPoly);
922  formula = builder.create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
923  formula, one);
924 
925  // erf is odd function: erf(x) = -erf(-x).
926  Value negFormula = builder.create<arith::NegFOp>(formula);
927  Value res =
928  builder.create<arith::SelectOp>(isNegativeArg, negFormula, formula);
929 
930  rewriter.replaceOp(op, res);
931 
932  return success();
933 }
934 
935 //----------------------------------------------------------------------------//
936 // Exp approximation.
937 //----------------------------------------------------------------------------//
938 
939 namespace {
940 
941 Value clampWithNormals(ImplicitLocOpBuilder &builder,
942  const llvm::ArrayRef<int64_t> shape, Value value,
943  float lowerBound, float upperBound) {
944  assert(!std::isnan(lowerBound));
945  assert(!std::isnan(upperBound));
946 
947  auto bcast = [&](Value value) -> Value {
948  return broadcast(builder, value, shape);
949  };
950 
951  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
952  return builder.create<arith::SelectOp>(
953  builder.create<arith::CmpFOp>(pred, value, bound), value, bound);
954  };
955 
956  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
957  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
958  // arith::{Max,Min}FOp.
959  value = selectCmp(arith::CmpFPredicate::UGE, value,
960  bcast(f32Cst(builder, lowerBound)));
961  value = selectCmp(arith::CmpFPredicate::ULE, value,
962  bcast(f32Cst(builder, upperBound)));
963  return value;
964 }
965 
// Rewrite pattern that expands math.exp on f32 into a Cephes-derived
// polynomial scaled by an integer power of two (see matchAndRewrite below).
966 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
967 public:
969 
970  LogicalResult matchAndRewrite(math::ExpOp op,
971  PatternRewriter &rewriter) const final;
972 };
973 
// Replaces math.exp on f32 elements (scalar or vector) with
// e^x = e^a * 2^n, computing e^a by polynomial and 2^n by building the
// float exponent bits directly (exp2I32). Denormal results are flushed to 0.
975 ExpApproximation::matchAndRewrite(math::ExpOp op,
976  PatternRewriter &rewriter) const {
977  auto shape = vectorShape(op.getOperand().getType());
978  auto elementTy = getElementTypeOrSelf(op.getType());
979  if (!elementTy.isF32())
980  return rewriter.notifyMatchFailure(op, "unsupported operand type");
981 
982  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
983 
984  auto add = [&](Value a, Value b) -> Value {
985  return builder.create<arith::AddFOp>(a, b);
986  };
987  auto bcast = [&](Value value) -> Value {
988  return broadcast(builder, value, shape);
989  };
990  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
// fmla(a, b, c) = a * b + c, as a single fused math.fma.
991  auto fmla = [&](Value a, Value b, Value c) {
992  return builder.create<math::FmaOp>(a, b, c);
993  };
994  auto mul = [&](Value a, Value b) -> Value {
995  return builder.create<arith::MulFOp>(a, b);
996  };
997 
998  // Polynomial approximation from Cephes.
999  //
1000  // To compute e^x, we re-express it as
1001  //
1002  // e^x = e^(a + b)
1003  // = e^(a + n log(2))
1004  // = e^a * 2^n.
1005  //
1006  // We choose n = round(x / log(2)), restricting the value of `a` to
1007  // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
1008  // relative error between our approximation and the true value of e^a is less
1009  // than 2^-22.5 for all values of `a` within this range.
1010 
1011  // Restrict input to a small range, including some values that evaluate to
1012  // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
1013  // log(F32_EPSILON). We do so because this routine always flushes denormal
1014  // floating points to 0. Therefore, we only need to worry about exponentiating
1015  // up to the smallest representable non-denormal floating point, which is
1016  // 2^-126.
1017 
1018  // Constants.
1019  Value cstHalf = bcast(f32Cst(builder, 0.5f));
1020  Value cstOne = bcast(f32Cst(builder, 1.0f));
1021 
1022  // 1/log(2)
1023  Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));
1024 
// -log(2) split into a high part (C1, exactly representable with few
// mantissa bits) and a small correction (C2), so that subtracting n * log(2)
// loses little precision. C1 + C2 ~= -0.693147.
1025  Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
1026  Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
1027  Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
1028  Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
1029  Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
1030  Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
1031  Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
1032  Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));
1033 
1034  // Our computations below aren't particularly sensitive to the exact choices
1035  // here, so we choose values a bit larger/smaller than
1036  //
1037  // log(F32_MAX) = 88.723...
1038  // log(2^-126) = -87.337...
1039  Value x = op.getOperand();
1040  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
// n = floor(x / log(2) + 0.5), i.e. x / log(2) rounded to nearest.
1041  Value n = floor(fmla(x, cstLog2ef, cstHalf));
1042 
1043  // When we eventually do the multiplication in e^a * 2^n, we need to handle
1044  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1045  // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
1046  // smallest fp32 exponent.
1047  //
1048  // A straightforward solution would be to detect n out of range and split it
1049  // up, doing
1050  //
1051  // e^a * 2^n = e^a * 2^(n1 + n2)
1052  // = (2^n1 * e^a) * 2^n2.
1053  //
1054  // But it turns out this approach is quite slow, probably because it
1055  // manipulates subnormal values.
1056  //
1057  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1058  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1059  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1060  // this value of `a` is outside our previously specified range, e^a will still
1061  // only have a relative error of approximately 2^-16 at worst. In practice
1062  // this seems to work well enough; it passes our exhaustive tests, breaking
1063  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1064  // we should return inf).
1065  //
1066  // In the case where n' = -127, the original input value of x is so small that
1067  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1068  // normal floating point, and since we flush denormals, we simply return 0. We
1069  // do this in a branchless way by observing that our code for constructing 2^n
1070  // produces 0 if n = -127.
1071  //
1072  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1073  //
1074  // n' = -127 implies n <= -127
1075  // implies round(x / log(2)) <= -127
1076  // implies x/log(2) < -126.5
1077  // implies x < -126.5 * log(2)
1078  // implies e^x < e^(-126.5 * log(2))
1079  // implies e^x < 2^-126.5 < 2^-126
1080  //
1081  // This proves that n' = -127 implies e^x < 2^-126.
1082  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1083 
1084  // Computes x = x - n' * log(2), the value for `a`
1085  x = fmla(cstExpC1, n, x);
1086  x = fmla(cstExpC2, n, x);
1087 
1088  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
1089  Value z = fmla(x, cstExpP0, cstExpP1);
1090  z = fmla(z, x, cstExpP2);
1091  z = fmla(z, x, cstExpP3);
1092  z = fmla(z, x, cstExpP4);
1093  z = fmla(z, x, cstExpP5);
1094  z = fmla(z, mul(x, x), x);
1095  z = add(cstOne, z);
1096 
1097  // Convert n' to an i32. This is safe because we clamped it above.
1098  auto i32Vec = broadcast(builder.getI32Type(), shape);
1099  Value nI32 = builder.create<arith::FPToSIOp>(i32Vec, n);
1100 
1101  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1102  Value pow2 = exp2I32(builder, nI32);
1103 
1104  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1105  Value ret = mul(z, pow2);
1106 
1107  rewriter.replaceOp(op, ret);
1108  return mlir::success();
1109 }
1110 
1111 } // namespace
1112 
1113 //----------------------------------------------------------------------------//
1114 // ExpM1 approximation.
1115 //----------------------------------------------------------------------------//
1116 
1117 namespace {
1118 
// Rewrite pattern that expands math.expm1 on f32 in terms of math.exp and
// math.log (see ExpM1Approximation::matchAndRewrite below).
1119 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1120 public:
1122 
1123  LogicalResult matchAndRewrite(math::ExpM1Op op,
1124  PatternRewriter &rewriter) const final;
1125 };
1126 } // namespace
1127 
// Replaces math.expm1 on f32 with (u - 1) * (x / log(u)) where u = exp(x).
// The ratio x / log(u) corrects the rounding error of u. Special cases are
// patched with selects. NOTE(review): the definition of `shape` used by
// `bcast` is on a line not shown in this listing.
1129 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1130  PatternRewriter &rewriter) const {
1131  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1132  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1133 
1135 
1136  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1137  auto bcast = [&](Value value) -> Value {
1138  return broadcast(builder, value, shape);
1139  };
1140 
1141  // expm1(x) = exp(x) - 1 = u - 1.
1142  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1143  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
1144  Value cstOne = bcast(f32Cst(builder, 1.0f));
1145  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
1146  Value x = op.getOperand();
1147  Value u = builder.create<math::ExpOp>(x);
// UEQ is true when u == 1 or when u is NaN. In both cases the result is x
// (expm1(x) ~= x for tiny x; NaN input yields NaN).
1148  Value uEqOneOrNaN =
1149  builder.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1150  Value uMinusOne = builder.create<arith::SubFOp>(u, cstOne);
1151  Value uMinusOneEqNegOne = builder.create<arith::CmpFOp>(
1152  arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1153  // logU = log(u) ~= x
1154  Value logU = builder.create<math::LogOp>(u);
1155 
1156  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1157  Value isInf =
1158  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1159 
1160  // (u - 1) * (x / ~x)
1161  Value expm1 = builder.create<arith::MulFOp>(
1162  uMinusOne, builder.create<arith::DivFOp>(x, logU));
1163  expm1 = builder.create<arith::SelectOp>(isInf, u, expm1);
// Priority: u == 1 or NaN -> x; u - 1 == -1 (exp underflowed to 0) -> -1;
// otherwise the corrected formula (or +inf on overflow).
1164  Value approximation = builder.create<arith::SelectOp>(
1165  uEqOneOrNaN, x,
1166  builder.create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1167  rewriter.replaceOp(op, approximation);
1168  return success();
1169 }
1170 
1171 //----------------------------------------------------------------------------//
1172 // Sin and Cos approximation.
1173 //----------------------------------------------------------------------------//
1174 
1175 namespace {
1176 
// Rewrite pattern shared by math.sin (isSine = true) and math.cos
// (isSine = false). Both are reduced to polynomials on [0, pi/2).
1177 template <bool isSine, typename OpTy>
1178 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1179 public:
1181 
1182  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1183 };
1184 } // namespace
1185 
// 2/pi and pi/2 in long double precision, used by the sin/cos range
// reduction; callers narrow them to float at the point of use.
// Typed constexpr constants instead of object-like macros: they obey scoping,
// carry a type, and are not textually substituted into unrelated code later in
// the translation unit.
constexpr long double TWO_OVER_PI =
    0.6366197723675813430755350534900574481378385829618257949906693762L;
constexpr long double PI_OVER_2 =
    1.5707963267948966192313216916397514420985846996875529104874722961L;
1190 
1191 // Approximates sin(x) or cos(x) by finding the best approximation polynomial in
1192 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
1193 // reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
// Quadrant table, with k = floor(x * 2/pi) and q = k mod 4:
// sin(x): q=0 -> sin(y), q=1 -> cos(y), q=2 -> -sin(y), q=3 -> -cos(y)
// cos(x): q=0 -> cos(y), q=1 -> -sin(y), q=2 -> -cos(y), q=3 -> sin(y)
// Only f32 is handled. NOTE(review): the definition of `shape` is on a line
// not shown in this listing.
1194 template <bool isSine, typename OpTy>
1195 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1196  OpTy op, PatternRewriter &rewriter) const {
1197  static_assert(
1198  llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1199  "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1200 
1201  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1202  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1203 
1205 
1206  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1207  auto bcast = [&](Value value) -> Value {
1208  return broadcast(builder, value, shape);
1209  };
1210  auto mul = [&](Value a, Value b) -> Value {
1211  return builder.create<arith::MulFOp>(a, b);
1212  };
1213  auto sub = [&](Value a, Value b) -> Value {
1214  return builder.create<arith::SubFOp>(a, b);
1215  };
1216  auto floor = [&](Value a) { return builder.create<math::FloorOp>(a); };
1217 
1218  auto i32Vec = broadcast(builder.getI32Type(), shape);
1219  auto fPToSingedInteger = [&](Value a) -> Value {
1220  return builder.create<arith::FPToSIOp>(i32Vec, a);
1221  };
1222 
// q & 3 equals q mod 4 for two's-complement i32, including negative q.
1223  auto modulo4 = [&](Value a) -> Value {
1224  return builder.create<arith::AndIOp>(a, bcast(i32Cst(builder, 3)));
1225  };
1226 
1227  auto isEqualTo = [&](Value a, Value b) -> Value {
1228  return builder.create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1229  };
1230 
1231  auto isGreaterThan = [&](Value a, Value b) -> Value {
1232  return builder.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1233  };
1234 
1235  auto select = [&](Value cond, Value t, Value f) -> Value {
1236  return builder.create<arith::SelectOp>(cond, t, f);
1237  };
1238 
1239  auto fmla = [&](Value a, Value b, Value c) {
1240  return builder.create<math::FmaOp>(a, b, c);
1241  };
1242 
1243  auto bitwiseOr = [&](Value a, Value b) {
1244  return builder.create<arith::OrIOp>(a, b);
1245  };
1246 
1247  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
1248  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
1249 
1250  Value x = op.getOperand();
1251 
// Range reduction: k = floor(x * 2/pi), y = x - k * pi/2 lies in [0, pi/2) up
// to rounding.
1252  Value k = floor(mul(x, twoOverPi));
1253 
1254  Value y = sub(x, mul(k, piOverTwo));
1255 
1256  Value cstOne = bcast(f32Cst(builder, 1.0));
1257  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
1258 
// sin(y) = y * (1 + SC2*y^2 + SC4*y^4 + ... + SC10*y^10).
1259  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
1260  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
1261  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1262  Value cstSC8 =
1263  bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1264  Value cstSC10 =
1265  bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1266 
// cos(y) = 1 * (1 + CC2*y^2 + CC4*y^4 + ... + CC10*y^10).
1267  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
1268  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
1269  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
1270  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1271  Value cstCC10 =
1272  bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1273 
// NOTE(review): for |x| large enough that k does not fit in i32, this
// fptosi is out of range; confirm the intended behavior for huge inputs.
1274  Value kMod4 = modulo4(fPToSingedInteger(k));
1275 
1276  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
1277  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
1278  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
1279  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
1280 
// sinuseCos: evaluate the cosine polynomial instead of the sine one.
// negativeRange: negate the polynomial result (see the quadrant table above).
1281  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1282  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
1283  : bitwiseOr(kR1, kR2);
1284 
1285  Value y2 = mul(y, y);
1286 
1287  Value base = select(sinuseCos, cstOne, y);
1288  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1289  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1290  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1291  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1292  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1293 
// Horner evaluation in y^2: v5 = 1 + C2*y^2 + C4*y^4 + C6*y^6 + C8*y^8 +
// C10*y^10; v6 = base * v5.
1294  Value v1 = fmla(y2, cstC10, cstC8);
1295  Value v2 = fmla(y2, v1, cstC6);
1296  Value v3 = fmla(y2, v2, cstC4);
1297  Value v4 = fmla(y2, v3, cstC2);
1298  Value v5 = fmla(y2, v4, cstOne);
1299  Value v6 = mul(base, v5);
1300 
1301  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1302 
1303  rewriter.replaceOp(op, approximation);
1304 
1305  return success();
1306 }
1307 
1308 //----------------------------------------------------------------------------//
1309 // Cbrt approximation.
1310 //----------------------------------------------------------------------------//
1311 
1312 namespace {
// Rewrite pattern that expands math.cbrt on f32 using an integer bit trick
// followed by Newton iterations (see matchAndRewrite below).
1313 struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1315 
1316  LogicalResult matchAndRewrite(math::CbrtOp op,
1317  PatternRewriter &rewriter) const final;
1318 };
1319 } // namespace
1320 
1321 // Estimation of cube-root using an algorithm defined in
1322 // Hacker's Delight 2nd Edition.
// Only f32 is accepted: the shift amounts and the 0x2a5137a0 bias assume the
// 32-bit float bit layout. The result is computed on |x|, and the sign is
// restored with copysign.
1324 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1325  PatternRewriter &rewriter) const {
1326  auto operand = op.getOperand();
1327  if (!getElementTypeOrSelf(operand).isF32())
1328  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1329 
1330  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1331  ArrayRef<int64_t> shape = vectorShape(operand);
1332 
1333  Type floatTy = getElementTypeOrSelf(operand.getType());
1334  Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1335 
1336  // Convert to vector types if necessary.
1337  floatTy = broadcast(floatTy, shape);
1338  intTy = broadcast(intTy, shape);
1339 
1340  auto bconst = [&](TypedAttr attr) -> Value {
1341  Value value = b.create<arith::ConstantOp>(attr);
1342  return broadcast(b, value, shape);
1343  };
1344 
1345  // Declare the initial values:
1346  Value intTwo = bconst(b.getI32IntegerAttr(2));
1347  Value intFour = bconst(b.getI32IntegerAttr(4));
1348  Value intEight = bconst(b.getI32IntegerAttr(8));
1349  Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1350  Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1351  Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1352  Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1353 
1354  // Compute an approximation of one third:
// (i.e. scale the integer bit pattern by ~1/3: ix/4 + ix/16 = 0.3125*ix, the
// two corrections below bring it to ~0.3333*ix, then a bias is added. This
// gives an initial guess for cbrt(|x|).)
1355  // union {int ix; float x;};
1356  // x = x0;
1357  // ix = ix/4 + ix/16;
1358  Value absValue = b.create<math::AbsFOp>(operand);
1359  Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1360  Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1361  Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1362  intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1363 
1364  // ix = ix + ix/16;
1365  divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1366  intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1367 
1368  // ix = ix + ix/256;
1369  Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1370  intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1371 
1372  // ix = 0x2a5137a0 + ix;
1373  intValue = b.create<arith::AddIOp>(intValue, intMagic);
1374 
1375  // Perform one Newton step:
1376  // x = 0.33333333f*(2.0f*x + x0/(x*x));
1377  Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1378  Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1379  Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1380  Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1381  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1382  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1383 
1384  // Second Newton step: x = 0.33333333f*(2.0f*x + x0/(x*x));
1385  squared = b.create<arith::MulFOp>(floatValue, floatValue);
1386  mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1387  divSquared = b.create<arith::DivFOp>(absValue, squared);
1388  floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1389  floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1390 
1391  // Check for zero and restore sign.
// For |x| == 0 the initial guess is just the bias, and the Newton steps do not
// yield 0, so zero is forced explicitly. copysign then also restores -0.0.
1392  Value isZero =
1393  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1394  floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1395  floatValue = b.create<math::CopySignOp>(floatValue, operand);
1396 
1397  rewriter.replaceOp(op, floatValue);
1398  return success();
1399 }
1400 
1401 //----------------------------------------------------------------------------//
1402 // Rsqrt approximation.
1403 //----------------------------------------------------------------------------//
1404 
1405 namespace {
// Rewrite pattern that expands math.rsqrt on 8-wide f32 vectors using the
// x86vector rsqrt intrinsic plus one Newton-Raphson refinement step.
1406 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1408 
1409  LogicalResult matchAndRewrite(math::RsqrtOp op,
1410  PatternRewriter &rewriter) const final;
1411 };
1412 } // namespace
1413 
// Replaces math.rsqrt on f32 vectors whose innermost dimension is a multiple
// of 8 (the width handed to handleMultidimensionalVectors for
// x86vector.rsqrt). Positive normal finite inputs get a Newton-refined
// result; all other inputs keep the raw intrinsic output. NOTE(review): the
// `shape` definition and the `yApprox =` declaration are on lines not shown
// in this listing.
1415 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1416  PatternRewriter &rewriter) const {
1417  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1418  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1419 
1421 
1422  // Only support already-vectorized rsqrt's.
1423  if (shape.empty() || shape.back() % 8 != 0)
1424  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1425 
1426  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1427  auto bcast = [&](Value value) -> Value {
1428  return broadcast(builder, value, shape);
1429  };
1430 
// 0x7f800000 is +inf and 0x00800000 is the smallest positive normal f32.
1431  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1432  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1433  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1434  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1435 
1436  Value negHalf = builder.create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1437 
1438  // Select only the inverse sqrt of positive normals (denormals are
1439  // flushed to zero).
1440  Value ltMinMask = builder.create<arith::CmpFOp>(
1441  arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1442  Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1443  op.getOperand(), cstPosInf);
1444  Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);
1445 
1446  // Compute an approximate result.
1448  builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1449  return builder.create<x86vector::RsqrtOp>(operands);
1450  });
1451 
1452  // Do a single step of Newton-Raphson iteration to improve the approximation.
1453  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1454  // It is essential to evaluate the inner term like this because forming
1455  // y_n^2 may over- or underflow.
1456  Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
1457  Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1458  Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);
1459 
1460  // Select the result of the Newton-Raphson step for positive normal arguments.
1461  // For other arguments, choose the output of the intrinsic. This will
1462  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1463  // x is zero or a positive denormalized float (equivalent to flushing positive
1464  // denormalized inputs to zero).
1465  Value res =
1466  builder.create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1467  rewriter.replaceOp(op, res);
1468 
1469  return success();
1470 }
1471 
1472 //----------------------------------------------------------------------------//
1473 
1475  RewritePatternSet &patterns) {
// Registers only the tanh polynomial expansion.
1476  patterns.add<TanhApproximation>(patterns.getContext());
1477 }
1478 
1480  RewritePatternSet &patterns) {
// Registers only the erf rational-polynomial expansion.
1481  patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1482 }
1483 
1485  RewritePatternSet &patterns,
1487  // Patterns for leveraging existing f32 lowerings on other data types.
// NOTE(review): ReuseF32Expansion is defined outside this listing; presumably
// it converts non-f32 operands to f32 around the f32 expansions registered
// below — confirm.
1488  patterns
1489  .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1490  ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1491  ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1492  ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1493  ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1494  ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1495  patterns.getContext());
1496 
// Target-independent polynomial expansions.
1497  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
1498  LogApproximation, Log2Approximation, Log1pApproximation,
1499  ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1500  CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1501  SinAndCosApproximation<false, math::CosOp>>(
1502  patterns.getContext());
// The rsqrt expansion emits x86vector.rsqrt, so it is opt-in via enableAvx2.
1503  if (options.enableAvx2) {
1504  patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1505  patterns.getContext());
1506  }
1507 }
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
#define LN2_VALUE
static Type broadcast(Type type, ArrayRef< int64_t > shape)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
#define PI_OVER_2
#define TWO_OVER_PI
static ArrayRef< int64_t > vectorShape(Type type)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
#define LOG2E_VALUE
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
FloatType getF32Type()
Definition: Builders.cpp:63
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
FloatAttr getF32FloatAttr(float value)
Definition: Builders.cpp:253
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
MLIRContext * getContext() const
Definition: PatternMatch.h:785
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:809
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:685
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:537
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:51
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
Type front()
Return first type in the range.
Definition: TypeRange.h:148
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
MPInt floor(const Fraction &f)
Definition: Fraction.h:74
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
Definition: IndexingUtils.h:69
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.