PolynomialApproximation.cpp
1 //===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of math operations to fast approximations
10 // that do not rely on any of the library functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <climits>
15 #include <cmath>
16 #include <cstddef>
17 
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/BuiltinTypes.h"
28 #include "mlir/IR/OpDefinition.h"
29 #include "mlir/IR/PatternMatch.h"
30 #include "mlir/IR/TypeUtilities.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/MathExtras.h"
35 
36 using namespace mlir;
37 using namespace mlir::math;
38 using namespace mlir::vector;
39 
40 // Helper to encapsulate a vector's shape (including scalable dims).
41 struct VectorShape {
42  ArrayRef<int64_t> sizes;
43  ArrayRef<bool> scalableFlags;
44 };
45 
46 // Returns the vector shape if the type is a vector; otherwise returns std::nullopt.
47 static std::optional<VectorShape> vectorShape(Type type) {
48  if (auto vectorType = dyn_cast<VectorType>(type)) {
49  return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
50  }
51  return std::nullopt;
52 }
53 
54 static std::optional<VectorShape> vectorShape(Value value) {
55  return vectorShape(value.getType());
56 }
57 
58 //----------------------------------------------------------------------------//
59 // Broadcast scalar types and values into vector types and values.
60 //----------------------------------------------------------------------------//
61 
62 // Broadcasts scalar type into vector type (iff shape is non-scalar).
63 static Type broadcast(Type type, std::optional<VectorShape> shape) {
64  assert(!isa<VectorType>(type) && "must be scalar type");
65  return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
66  : type;
67 }
68 
69 // Broadcasts scalar value into vector (iff shape is non-scalar).
70 static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
71  std::optional<VectorShape> shape) {
72  assert(!isa<VectorType>(value.getType()) && "must be scalar value");
73  auto type = broadcast(value.getType(), shape);
74  return shape ? BroadcastOp::create(builder, type, value) : value;
75 }
76 
77 //----------------------------------------------------------------------------//
78 // Helper function to handle n-D vectors with 1-D operations.
79 //----------------------------------------------------------------------------//
80 
81 // Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
82 // and calls the compute function with 1-D vector operands. Stitches back all
83 // results into the original n-D vector result.
84 //
85 // Examples: vectorWidth = 8
86 // - vector<4x8xf32> unrolled 4 times
87 // - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
88 // - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
89 //
90 // Some math approximations rely on ISA-specific operations that only accept
91 // fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
92 //
93 // It is the caller's responsibility to verify that the inner dimension is
94 // divisible by the vectorWidth, and that all operands have the same vector
95 // shape.
96 static Value
97 handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
98  ValueRange operands, int64_t vectorWidth,
99  llvm::function_ref<Value(ValueRange)> compute) {
100  assert(!operands.empty() && "operands must not be empty");
101  assert(vectorWidth > 0 && "vector width must be larger than 0");
102 
103  VectorType inputType = cast<VectorType>(operands[0].getType());
104  ArrayRef<int64_t> inputShape = inputType.getShape();
105 
106  // If input shape matches target vector width, we can just call the
107  // user-provided compute function with the operands.
108  if (inputShape == llvm::ArrayRef(vectorWidth))
109  return compute(operands);
110 
111  // Check if the inner dimension has to be expanded, or we can directly iterate
112  // over the outer dimensions of the vector.
113  int64_t innerDim = inputShape.back();
114  int64_t expansionDim = innerDim / vectorWidth;
115  assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
116 
117  // Maybe expand operands to the higher rank vector shape that we'll use to
118  // iterate over and extract one dimensional vectors.
119  SmallVector<int64_t> expandedShape(inputShape);
120  SmallVector<Value> expandedOperands(operands);
121 
122  if (expansionDim > 1) {
123  // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
124  expandedShape.insert(expandedShape.end() - 1, expansionDim);
125  expandedShape.back() = vectorWidth;
126 
127  for (unsigned i = 0; i < operands.size(); ++i) {
128  auto operand = operands[i];
129  auto eltType = cast<VectorType>(operand.getType()).getElementType();
130  auto expandedType = VectorType::get(expandedShape, eltType);
131  expandedOperands[i] =
132  vector::ShapeCastOp::create(builder, expandedType, operand);
133  }
134  }
135 
136  // Iterate over all outer dimensions of the compute shape vector type.
137  auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
138  int64_t maxIndex = computeMaxLinearIndex(iterationDims);
139  auto strides = computeStrides(iterationDims);
140 
141  // Compute results for each one dimensional vector.
142  SmallVector<Value> results(maxIndex);
143 
144  for (int64_t i = 0; i < maxIndex; ++i) {
145  auto offsets = delinearize(i, strides);
146 
147  SmallVector<Value> extracted(expandedOperands.size());
148  for (const auto &tuple : llvm::enumerate(expandedOperands))
149  extracted[tuple.index()] =
150  vector::ExtractOp::create(builder, tuple.value(), offsets);
151 
152  results[i] = compute(extracted);
153  }
154 
155  // Stitch results together into one large vector.
156  Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
157  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
158  Value result = arith::ConstantOp::create(
159  builder, resultExpandedType, builder.getZeroAttr(resultExpandedType));
160 
161  for (int64_t i = 0; i < maxIndex; ++i)
162  result = vector::InsertOp::create(builder, results[i], result,
163  delinearize(i, strides));
164 
165  // Reshape back to the original vector shape.
166  return vector::ShapeCastOp::create(
167  builder, VectorType::get(inputShape, resultEltType), result);
168 }
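// To make the iteration above concrete (illustrative note, not part of the
// lowering): for an input of type vector<4x16xf32> with vectorWidth = 8, the
// expanded shape is [4, 2, 8], iterationDims is [4, 2], strides is [2, 1] and
// maxIndex is 8. For i = 5, delinearize(5, [2, 1]) yields offsets [2, 1], so
// the 1-D slice at [2][1] of each expanded operand is extracted, handed to
// `compute`, re-inserted at the same offsets, and finally shape_cast back to
// vector<4x16xf32>.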
169 
170 //----------------------------------------------------------------------------//
171 // Helper functions to create constants.
172 //----------------------------------------------------------------------------//
173 
174 static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
175  return arith::ConstantOp::create(builder, builder.getBoolAttr(value));
176 }
177 
178 static Value floatCst(ImplicitLocOpBuilder &builder, float value,
179  Type elementType) {
180  assert((elementType.isF16() || elementType.isF32()) &&
181  "elementType must be f16 or f32 type.");
182  return arith::ConstantOp::create(builder,
183  builder.getFloatAttr(elementType, value));
184 }
185 
186 static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
187  return arith::ConstantOp::create(builder, builder.getF32FloatAttr(value));
188 }
189 
190 static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
191  return arith::ConstantOp::create(builder, builder.getI32IntegerAttr(value));
192 }
193 
194 static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
195  Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
196  return arith::BitcastOp::create(builder, builder.getF32Type(), i32Value);
197 }
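// A few of the bit patterns passed to f32FromBits later in this file, for
// reference:
//   0x7f800000 -> +inf            0xff800000 -> -inf
//   0x7fc00000 -> quiet NaN       0x00800000 -> 2^-126 (smallest normal)
//   ~0x7f800000 -> mask that keeps only the sign and mantissa bits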
198 
199 //----------------------------------------------------------------------------//
200 // Helper functions to build math functions approximations.
201 //----------------------------------------------------------------------------//
202 
203 // Return the minimum of the two values or NaN if value is NaN
204 static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
205  return arith::SelectOp::create(
206  builder,
207  arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound),
208  value, bound);
209 }
210 
211 // Return the maximum of the two values or NaN if value is NaN
212 static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
213  return arith::SelectOp::create(
214  builder,
215  arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound),
216  value, bound);
217 }
218 
219 // Return the clamped value or NaN if value is NaN
220 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
221  Value upperBound) {
222  return max(builder, min(builder, value, upperBound), lowerBound);
223 }
224 
225 // Decomposes given floating point value `arg` into a normalized fraction and
226 // an integral power of two (see std::frexp). Returned values have float type.
227 static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228  bool isPositive = false) {
229  assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230  std::optional<VectorShape> shape = vectorShape(arg);
231 
232  auto bcast = [&](Value value) -> Value {
233  return broadcast(builder, value, shape);
234  };
235 
236  auto i32 = builder.getIntegerType(32);
237  auto i32Vec = broadcast(i32, shape);
238  auto f32Vec = broadcast(builder.getF32Type(), shape);
239 
240  Value cst126f = f32Cst(builder, 126.0f);
241  Value cstHalf = f32Cst(builder, 0.5f);
242  Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
243 
244  // Bitcast to i32 for bitwise operations.
245  Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf);
246  Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask);
247  Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg);
248 
249  // Compute normalized fraction.
250  Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask));
251  Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half));
252  Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1);
253 
254  // Compute exponent.
255  Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg);
256  Value biasedExponentBits = arith::ShRUIOp::create(
257  builder, arith::BitcastOp::create(builder, i32Vec, arg0),
258  bcast(i32Cst(builder, 23)));
259  Value biasedExponent =
260  arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits);
261  Value exponent =
262  arith::SubFOp::create(builder, biasedExponent, bcast(cst126f));
263 
264  return {normalizedFraction, exponent};
265 }
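// Worked scalar example of the bit manipulation above (illustrative only),
// assuming IEEE-754 binary32: for arg = 24.0f,
//   bits(24.0f)                    = 0x41C00000
//   bits & ~0x7f800000             = 0x00400000   (sign + mantissa)
//   ... | bits(0.5f) = 0x3F400000  -> 0.75f        (normalized fraction)
//   biased exponent = 0x41C00000 >> 23 = 131, 131 - 126 = 5
// i.e. 24.0 = 0.75 * 2^5, matching std::frexp with the bias of 126 used here.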
266 
267 // Computes exp2 for an i32 argument.
268 static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269  assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270  std::optional<VectorShape> shape = vectorShape(arg);
271 
272  auto bcast = [&](Value value) -> Value {
273  return broadcast(builder, value, shape);
274  };
275 
276  auto f32Vec = broadcast(builder.getF32Type(), shape);
277  // The exponent bits of an f32 start at bit 23.
278  auto exponentBitLocation = bcast(i32Cst(builder, 23));
279  // Set the exponent bias to zero.
280  auto bias = bcast(i32Cst(builder, 127));
281 
282  Value biasedArg = arith::AddIOp::create(builder, arg, bias);
283  Value exp2ValueInt =
284  arith::ShLIOp::create(builder, biasedArg, exponentBitLocation);
285  Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt);
286 
287  return exp2ValueF32;
288 }
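// Worked example of the trick above (illustrative only): for arg = 3,
// (3 + 127) << 23 = 0x41000000, which bitcast to f32 is 8.0f = 2^3. For
// arg = -127 the shifted value is 0, so the result is +0.0f; the exp
// approximation below relies on exactly this flush-to-zero behaviour.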
289 
290 namespace {
291 Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
292  llvm::ArrayRef<Value> coeffs, Value x) {
293  Type elementType = getElementTypeOrSelf(x);
294  assert((elementType.isF32() || elementType.isF16()) &&
295  "x must be f32 or f16 type");
296  std::optional<VectorShape> shape = vectorShape(x);
297 
298  if (coeffs.empty())
299  return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
300 
301  if (coeffs.size() == 1)
302  return coeffs[0];
303 
304  Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1],
305  coeffs[coeffs.size() - 2]);
306  for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307  res = math::FmaOp::create(builder, x, res, coeffs[i]);
308  }
309  return res;
310 }
311 } // namespace
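// makePolynomialCalculation evaluates the polynomial in Horner form with
// fused multiply-adds, coeffs[i] being the coefficient of x^i. For example
// (illustrative only), coeffs = {c0, c1, c2, c3} yields
//   fma(x, fma(x, fma(x, c3, c2), c1), c0) = c0 + c1*x + c2*x^2 + c3*x^3.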
312 
313 //----------------------------------------------------------------------------//
314 // Helper function/pattern to insert casts for reusing F32 bit expansion.
315 //----------------------------------------------------------------------------//
316 
317 template <typename T>
318 LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
319  // Conservatively only allow ops whose operands and results all share one type.
320  Type origType = op->getResultTypes().front();
321  for (Type t : llvm::drop_begin(op->getResultTypes()))
322  if (origType != t)
323  return rewriter.notifyMatchFailure(op, "required all types to match");
324  for (Type t : op->getOperandTypes())
325  if (origType != t)
326  return rewriter.notifyMatchFailure(op, "required all types to match");
327 
328  // Skip if already F32 or larger than 32 bits.
329  if (getElementTypeOrSelf(origType).isF32() ||
330  getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
331  return failure();
332 
333  // Create F32 equivalent type.
334  Type newType;
335  if (auto shaped = dyn_cast<ShapedType>(origType)) {
336  newType = shaped.clone(rewriter.getF32Type());
337  } else if (isa<FloatType>(origType)) {
338  newType = rewriter.getF32Type();
339  } else {
340  return rewriter.notifyMatchFailure(op,
341  "unable to find F32 equivalent type");
342  }
343 
344  Location loc = op->getLoc();
345  SmallVector<Value> operands;
346  for (auto operand : op->getOperands())
347  operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand));
348  auto result =
349  T::create(rewriter, loc, TypeRange{newType}, operands, op->getAttrs());
350  rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
351  return success();
352 }
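// As a sketch of the rewrite (the exact textual IR may differ), an f16 op
// such as
//   %r = math.atan %x : f16
// becomes
//   %e = arith.extf %x : f16 to f32
//   %a = math.atan %e : f32
//   %r = arith.truncf %a : f32 to f16
// so that the f32 approximations below can serve narrower float types.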
353 
354 namespace {
355 // Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
356 // op.
357 // TODO: Consider revising to avoid adding multiple casts for a subgraph that is
358 // all in lower precision. Currently this is only fallback support and performs
359 // simplistic casting.
360 template <typename T>
361 struct ReuseF32Expansion : public OpRewritePattern<T> {
362 public:
363  using OpRewritePattern<T>::OpRewritePattern;
364  LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
365  static_assert(
366  T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367  "requires same operands and result types");
368  return insertCasts<T>(op, rewriter);
369  }
370 };
371 } // namespace
372 
373 //----------------------------------------------------------------------------//
374 // AtanOp approximation.
375 //----------------------------------------------------------------------------//
376 
377 namespace {
378 struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
379 public:
380  using OpRewritePattern::OpRewritePattern;
381 
382  LogicalResult matchAndRewrite(math::AtanOp op,
383  PatternRewriter &rewriter) const final;
384 };
385 } // namespace
386 
387 LogicalResult
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
389  PatternRewriter &rewriter) const {
390  auto operand = op.getOperand();
391  if (!getElementTypeOrSelf(operand).isF32())
392  return rewriter.notifyMatchFailure(op, "unsupported operand type");
393 
394  std::optional<VectorShape> shape = vectorShape(op.getOperand());
395 
396  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397  Value abs = math::AbsFOp::create(builder, operand);
398 
399  auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
400 
401  // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
402  auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
403  Value cmp2 =
404  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds);
405  Value addone = arith::AddFOp::create(builder, abs, one);
406  Value subone = arith::SubFOp::create(builder, abs, one);
407  Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs);
408  Value xden = arith::SelectOp::create(builder, cmp2, addone, one);
409 
410  auto bcast = [&](Value value) -> Value {
411  return broadcast(builder, value, shape);
412  };
413 
414  // When x <= 0.66 we keep x, and when x > 2.41 we use 1/x instead:
415  auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
416  Value cmp1 =
417  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8);
418  xnum = arith::SelectOp::create(builder, cmp1, one, xnum);
419  xden = arith::SelectOp::create(builder, cmp1, abs, xden);
420 
421  Value x = arith::DivFOp::create(builder, xnum, xden);
422  Value xx = arith::MulFOp::create(builder, x, x);
423 
424  // Evaluate the rational polynomial approximation for atan over the range
425  // [0.0, 0.66].
426  auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
427  auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
428  auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
429  auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
430  auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
431  auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
432  auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
433  auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
434  auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
435  auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
436 
437  // Apply the polynomial approximation for the numerator:
438  Value n = p0;
439  n = math::FmaOp::create(builder, xx, n, p1);
440  n = math::FmaOp::create(builder, xx, n, p2);
441  n = math::FmaOp::create(builder, xx, n, p3);
442  n = math::FmaOp::create(builder, xx, n, p4);
443  n = arith::MulFOp::create(builder, n, xx);
444 
445  // Apply the polynomial approximation for the denominator:
446  Value d = q0;
447  d = math::FmaOp::create(builder, xx, d, q1);
448  d = math::FmaOp::create(builder, xx, d, q2);
449  d = math::FmaOp::create(builder, xx, d, q3);
450  d = math::FmaOp::create(builder, xx, d, q4);
451 
452  // Compute approximation of theta:
453  Value ans0 = arith::DivFOp::create(builder, n, d);
454  ans0 = math::FmaOp::create(builder, ans0, x, x);
455 
456  // Correct for the input mapping's angles:
457  Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
458  Value ans2 = arith::AddFOp::create(builder, mpi4, ans0);
459  Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0);
460 
461  Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
462  Value ans1 = arith::SubFOp::create(builder, mpi2, ans0);
463  ans = arith::SelectOp::create(builder, cmp1, ans1, ans);
464 
465  // Correct for signing of the input.
466  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
467  return success();
468 }
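// The two range reductions used above are the standard identities
//   atan(x) = pi/4 + atan((x - 1) / (x + 1))   for 0.66 < x <= tan(3*pi/8)
//   atan(x) = pi/2 - atan(1 / x)               for x > tan(3*pi/8)
// with tan(3*pi/8) ~= 2.41421356, so the rational approximation is only ever
// evaluated on arguments of magnitude at most about 0.66.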
469 
470 //----------------------------------------------------------------------------//
471 // Atan2Op approximation.
472 //----------------------------------------------------------------------------//
473 
474 namespace {
475 struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
476 public:
477  using OpRewritePattern::OpRewritePattern;
478 
479  LogicalResult matchAndRewrite(math::Atan2Op op,
480  PatternRewriter &rewriter) const final;
481 };
482 } // namespace
483 
484 LogicalResult
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486  PatternRewriter &rewriter) const {
487  auto y = op.getOperand(0);
488  auto x = op.getOperand(1);
489  if (!getElementTypeOrSelf(x).isF32())
490  return rewriter.notifyMatchFailure(op, "unsupported operand type");
491 
492  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493  std::optional<VectorShape> shape = vectorShape(op.getResult());
494 
495  // Compute atan in the valid range.
496  auto div = arith::DivFOp::create(builder, y, x);
497  auto atan = math::AtanOp::create(builder, div);
498 
499  // Determine what the atan would be for a 180 degree rotation.
500  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
501  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
502  auto addPi = arith::AddFOp::create(builder, atan, pi);
503  auto subPi = arith::SubFOp::create(builder, atan, pi);
504  auto atanGt =
505  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero);
506  auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi);
507 
508  // Determine whether to directly use atan or use the 180 degree flip
509  auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero);
510  Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan);
511 
512  // Handle x = 0, y > 0
513  Value xZero =
514  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero);
515  Value yGt =
516  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero);
517  Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt);
518  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
519  result = arith::SelectOp::create(builder, isHalfPi, halfPi, result);
520 
521  // Handle x = 0, y < 0
522  Value yLt =
523  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero);
524  Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt);
525  auto negativeHalfPiPi =
526  broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
527  result = arith::SelectOp::create(builder, isNegativeHalfPiPi,
528  negativeHalfPiPi, result);
529 
530  // Handle x = 0, y = 0;
531  Value yZero =
532  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero);
533  Value isNan = arith::AndIOp::create(builder, xZero, yZero);
534  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
535  result = arith::SelectOp::create(builder, isNan, cstNan, result);
536 
537  rewriter.replaceOp(op, result);
538  return success();
539 }
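// Example of the quadrant fix-up above (illustrative only): atan2(1.0, -1.0).
// The raw value is atan(1.0 / -1.0) = -pi/4; since atan <= 0 the flipped
// value is -pi/4 + pi = 3*pi/4, and since x < 0 that flipped value is chosen,
// which is the expected second-quadrant angle.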
540 
541 //----------------------------------------------------------------------------//
542 // TanhOp approximation.
543 //----------------------------------------------------------------------------//
544 
545 namespace {
546 struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
547 public:
548  using OpRewritePattern::OpRewritePattern;
549 
550  LogicalResult matchAndRewrite(math::TanhOp op,
551  PatternRewriter &rewriter) const final;
552 };
553 } // namespace
554 
555 LogicalResult
556 TanhApproximation::matchAndRewrite(math::TanhOp op,
557  PatternRewriter &rewriter) const {
558  if (!getElementTypeOrSelf(op.getOperand()).isF32())
559  return rewriter.notifyMatchFailure(op, "unsupported operand type");
560 
561  std::optional<VectorShape> shape = vectorShape(op.getOperand());
562 
563  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
564  auto bcast = [&](Value value) -> Value {
565  return broadcast(builder, value, shape);
566  };
567 
568  // Clamp the operand into the [minusClamp, plusClamp] range.
569  Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
570  Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
571  Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
572 
573  // Mask for tiny values that are approximated with `operand`.
574  Value tiny = bcast(f32Cst(builder, 0.0004f));
575  Value tinyMask = arith::CmpFOp::create(
576  builder, arith::CmpFPredicate::OLT,
577  math::AbsFOp::create(builder, op.getOperand()), tiny);
578 
579  // The monomial coefficients of the numerator polynomial (odd).
580  Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
581  Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
582  Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
583  Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
584  Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
585  Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
586  Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
587 
588  // The monomial coefficients of the denominator polynomial (even).
589  Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
590  Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
591  Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
592  Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
593 
594  // Since the polynomials are odd/even, we need x^2.
595  Value x2 = arith::MulFOp::create(builder, x, x);
596 
597  // Evaluate the numerator polynomial p.
598  Value p = math::FmaOp::create(builder, x2, alpha13, alpha11);
599  p = math::FmaOp::create(builder, x2, p, alpha9);
600  p = math::FmaOp::create(builder, x2, p, alpha7);
601  p = math::FmaOp::create(builder, x2, p, alpha5);
602  p = math::FmaOp::create(builder, x2, p, alpha3);
603  p = math::FmaOp::create(builder, x2, p, alpha1);
604  p = arith::MulFOp::create(builder, x, p);
605 
606  // Evaluate the denominator polynomial q.
607  Value q = math::FmaOp::create(builder, x2, beta6, beta4);
608  q = math::FmaOp::create(builder, x2, q, beta2);
609  q = math::FmaOp::create(builder, x2, q, beta0);
610 
611  // Divide the numerator by the denominator.
612  Value res = arith::SelectOp::create(builder, tinyMask, x,
613  arith::DivFOp::create(builder, p, q));
614 
615  rewriter.replaceOp(op, res);
616 
617  return success();
618 }
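// In other words, on the clamped range tanh is approximated by the odd/even
// rational function tanh(x) ~= x * P(x^2) / Q(x^2), and for |x| < 0.0004 by
// tanh(x) ~= x, which agrees with the true value to about one f32 ulp.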
619 
620 #define LN2_VALUE \
621  0.693147180559945309417232121458176568075500134360255254120680009493393621L
622 #define LOG2E_VALUE \
623  1.442695040888963407359924681001892137426645954152985934135449406931109219L
624 
625 //----------------------------------------------------------------------------//
626 // LogOp and Log2Op approximation.
627 //----------------------------------------------------------------------------//
628 
629 namespace {
630 template <typename Op>
631 struct LogApproximationBase : public OpRewritePattern<Op> {
632  using OpRewritePattern<Op>::OpRewritePattern;
633 
634  /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
635  LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
636  bool base2) const;
637 };
638 } // namespace
639 
640 // This approximation comes from Julien Pommier's SSE math library.
641 // Link: http://gruntthepeon.free.fr/ssemath
642 template <typename Op>
643 LogicalResult
644 LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
645  bool base2) const {
646  if (!getElementTypeOrSelf(op.getOperand()).isF32())
647  return rewriter.notifyMatchFailure(op, "unsupported operand type");
648 
649  std::optional<VectorShape> shape = vectorShape(op.getOperand());
650 
651  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
652  auto bcast = [&](Value value) -> Value {
653  return broadcast(builder, value, shape);
654  };
655 
656  Value cstZero = bcast(f32Cst(builder, 0.0f));
657  Value cstOne = bcast(f32Cst(builder, 1.0f));
658  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
659 
660  // The smallest positive normalized (non-denormal) float number.
661  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
662  Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
663  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
664  Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
665 
666  // Polynomial coefficients.
667  Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
668  Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
669  Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
670  Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
671  Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
672  Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
673  Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
674  Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
675  Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
676  Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
677 
678  Value x = op.getOperand();
679 
680  // Truncate input values to the minimum positive normal.
681  x = max(builder, x, cstMinNormPos);
682 
683  // Extract the significand in the range [0.5, 1) and the exponent.
684  std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
685  x = pair.first;
686  Value e = pair.second;
687 
688  // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
689  // by -1.0. The values are then centered around 0, which improves the
690  // stability of the polynomial evaluation:
691  //
692  // if( x < SQRTHF ) {
693  // e -= 1;
694  // x = x + x - 1.0;
695  // } else { x = x - 1.0; }
696  Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x,
697  cstCephesSQRTHF);
698  Value tmp = arith::SelectOp::create(builder, mask, x, cstZero);
699 
700  x = arith::SubFOp::create(builder, x, cstOne);
701  e = arith::SubFOp::create(
702  builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero));
703  x = arith::AddFOp::create(builder, x, tmp);
704 
705  Value x2 = arith::MulFOp::create(builder, x, x);
706  Value x3 = arith::MulFOp::create(builder, x2, x);
707 
708  // Evaluate the polynomial approximant of degree 8 in three parts.
709  Value y0, y1, y2;
710  y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1);
711  y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4);
712  y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7);
713  y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2);
714  y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5);
715  y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8);
716  y0 = math::FmaOp::create(builder, y0, x3, y1);
717  y0 = math::FmaOp::create(builder, y0, x3, y2);
718  y0 = arith::MulFOp::create(builder, y0, x3);
719 
720  y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0);
721  x = arith::AddFOp::create(builder, x, y0);
722 
723  if (base2) {
724  Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
725  x = math::FmaOp::create(builder, x, cstLog2e, e);
726  } else {
727  Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
728  x = math::FmaOp::create(builder, e, cstLn2, x);
729  }
730 
731  Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT,
732  op.getOperand(), cstZero);
733  Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
734  op.getOperand(), cstZero);
735  Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
736  op.getOperand(), cstPosInf);
737 
738  // Filter out invalid values:
739  // • x == 0 -> -INF
740  // • x < 0 -> NAN
741  // • x == +INF -> +INF
742  Value approximation = arith::SelectOp::create(
743  builder, zeroMask, cstMinusInf,
744  arith::SelectOp::create(
745  builder, invalidMask, cstNan,
746  arith::SelectOp::create(builder, posInfMask, cstPosInf, x)));
747 
748  rewriter.replaceOp(op, approximation);
749 
750  return success();
751 }
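// The reconstruction step follows from writing x = m * 2^e with m in
// [sqrt(1/2), sqrt(2)):
//   log(x)  = log(m) + e * ln(2)       (the natural-log branch)
//   log2(x) = log(m) * log2(e) + e     (the base2 branch)
// where log(m) is the polynomial evaluated above on m - 1.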
752 
753 namespace {
754 struct LogApproximation : public LogApproximationBase<math::LogOp> {
755  using LogApproximationBase::LogApproximationBase;
756 
757  LogicalResult matchAndRewrite(math::LogOp op,
758  PatternRewriter &rewriter) const final {
759  return logMatchAndRewrite(op, rewriter, /*base2=*/false);
760  }
761 };
762 } // namespace
763 
764 namespace {
765 struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
766  using LogApproximationBase::LogApproximationBase;
767 
768  LogicalResult matchAndRewrite(math::Log2Op op,
769  PatternRewriter &rewriter) const final {
770  return logMatchAndRewrite(op, rewriter, /*base2=*/true);
771  }
772 };
773 } // namespace
774 
775 //----------------------------------------------------------------------------//
776 // Log1p approximation.
777 //----------------------------------------------------------------------------//
778 
779 namespace {
780 struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
781 public:
782  using OpRewritePattern::OpRewritePattern;
783 
784  LogicalResult matchAndRewrite(math::Log1pOp op,
785  PatternRewriter &rewriter) const final;
786 };
787 } // namespace
788 
789 // Approximate log(1+x).
790 LogicalResult
791 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
792  PatternRewriter &rewriter) const {
793  if (!getElementTypeOrSelf(op.getOperand()).isF32())
794  return rewriter.notifyMatchFailure(op, "unsupported operand type");
795 
796  std::optional<VectorShape> shape = vectorShape(op.getOperand());
797 
798  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
799  auto bcast = [&](Value value) -> Value {
800  return broadcast(builder, value, shape);
801  };
802 
803  // Approximate log(1+x) using the following, due to W. Kahan:
804  // u = x + 1.0;
805  // if (u == 1.0 || u == inf) return x;
806  // return x * log(u) / (u - 1.0);
807  // ^^^^^^^^^^^^^^^^^^^^^^
808  // "logLarge" below.
809  Value cstOne = bcast(f32Cst(builder, 1.0f));
810  Value x = op.getOperand();
811  Value u = arith::AddFOp::create(builder, x, cstOne);
812  Value uSmall =
813  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne);
814  Value logU = math::LogOp::create(builder, u);
815  Value uInf =
816  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU);
817  Value logLarge = arith::MulFOp::create(
818  builder, x,
819  arith::DivFOp::create(builder, logU,
820  arith::SubFOp::create(builder, u, cstOne)));
821  Value approximation = arith::SelectOp::create(
822  builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge);
823  rewriter.replaceOp(op, approximation);
824  return success();
825 }
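// A quick sanity check of Kahan's formula (illustrative only): for x = 1e-8f,
// u = 1 + x rounds to exactly 1.0f, so log(u) == 0 and a naive log(1 + x)
// would return 0; the uSmall branch instead returns x, which matches
// log1p(x) ~= x to f32 accuracy. For moderate x, the factor log(u) / (u - 1)
// compensates for the rounding error committed when u was formed.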
826 
827 //----------------------------------------------------------------------------//
828 // Asin approximation.
829 //----------------------------------------------------------------------------//
830 
831 // Approximates asin(x).
832 // This approximation is based on the following stackoverflow post:
833 // https://stackoverflow.com/a/42683455
834 namespace {
835 struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
836 public:
837  using OpRewritePattern::OpRewritePattern;
838 
839  LogicalResult matchAndRewrite(math::AsinOp op,
840  PatternRewriter &rewriter) const final;
841 };
842 } // namespace
843 LogicalResult
844 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
845  PatternRewriter &rewriter) const {
846  Value operand = op.getOperand();
847  Type elementType = getElementTypeOrSelf(operand);
848 
849  if (!(elementType.isF32() || elementType.isF16()))
850  return rewriter.notifyMatchFailure(op,
851  "only f32 and f16 type is supported.");
852  std::optional<VectorShape> shape = vectorShape(operand);
853 
854  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
855  auto bcast = [&](Value value) -> Value {
856  return broadcast(builder, value, shape);
857  };
858 
859  auto fma = [&](Value a, Value b, Value c) -> Value {
860  return math::FmaOp::create(builder, a, b, c);
861  };
862 
863  auto mul = [&](Value a, Value b) -> Value {
864  return arith::MulFOp::create(builder, a, b);
865  };
866 
867  auto sub = [&](Value a, Value b) -> Value {
868  return arith::SubFOp::create(builder, a, b);
869  };
870 
871  auto abs = [&](Value a) -> Value { return math::AbsFOp::create(builder, a); };
872 
873  auto sqrt = [&](Value a) -> Value {
874  return math::SqrtOp::create(builder, a);
875  };
876 
877  auto scopy = [&](Value a, Value b) -> Value {
878  return math::CopySignOp::create(builder, a, b);
879  };
880 
881  auto sel = [&](Value a, Value b, Value c) -> Value {
882  return arith::SelectOp::create(builder, a, b, c);
883  };
884 
885  Value abso = abs(operand);
886  Value aa = mul(operand, operand);
887  Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
888 
889  Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa,
890  bcast(floatCst(builder, 0.5, elementType)));
891 
892  Value x = sel(gt, opp, abso);
893 
894  // Asin(x) approximation for x in [-9/16, 9/16]:
895  Value s = mul(x, x);
896  Value q = mul(s, s);
897  Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
898  Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
899 
900  r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
901  t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
902  r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
903  t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
904  r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
905  t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
906  r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
907  t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
908  r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
909  t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
910  r = fma(r, s, t);
911  r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
912  t = mul(x, s);
913  r = fma(r, t, x);
914 
915  Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
916  r = sel(gt, rsub, r);
917  r = scopy(r, operand);
918 
919  rewriter.replaceOp(op, r);
920  return success();
921 }
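// The `gt` branch above relies on the identity, valid for 0.5 < x^2 <= 1,
//   asin(|x|) = pi/2 - asin(sqrt(1 - x^2)),
// so the polynomial is only ever evaluated on arguments whose square is at
// most 0.5; the final copysign restores the sign of the original input.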
922 
923 //----------------------------------------------------------------------------//
924 // Acos approximation.
925 //----------------------------------------------------------------------------//
926 
927 // Approximates acos(x).
928 // This approximation is based on the following stackoverflow post:
929 // https://stackoverflow.com/a/42683455
930 namespace {
931 struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
932 public:
933  using OpRewritePattern::OpRewritePattern;
934 
935  LogicalResult matchAndRewrite(math::AcosOp op,
936  PatternRewriter &rewriter) const final;
937 };
938 } // namespace
939 LogicalResult
940 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941  PatternRewriter &rewriter) const {
942  Value operand = op.getOperand();
943  Type elementType = getElementTypeOrSelf(operand);
944 
945  if (!(elementType.isF32() || elementType.isF16()))
946  return rewriter.notifyMatchFailure(op,
947  "only f32 and f16 type is supported.");
948  std::optional<VectorShape> shape = vectorShape(operand);
949 
950  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
951  auto bcast = [&](Value value) -> Value {
952  return broadcast(builder, value, shape);
953  };
954 
955  auto fma = [&](Value a, Value b, Value c) -> Value {
956  return math::FmaOp::create(builder, a, b, c);
957  };
958 
959  auto mul = [&](Value a, Value b) -> Value {
960  return arith::MulFOp::create(builder, a, b);
961  };
962 
963  Value negOperand = arith::NegFOp::create(builder, operand);
964  Value zero = bcast(floatCst(builder, 0.0, elementType));
965  Value half = bcast(floatCst(builder, 0.5, elementType));
966  Value negOne = bcast(floatCst(builder, -1.0, elementType));
967  Value selR =
968  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero);
969  Value r = arith::SelectOp::create(builder, selR, negOperand, operand);
970  Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
971  Value firstPred =
972  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst);
973 
974  Value trueVal =
975  fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
976  bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
977  math::AsinOp::create(builder, r));
978 
979  Value falseVal = math::SqrtOp::create(builder, fma(half, r, half));
980  falseVal = math::AsinOp::create(builder, falseVal);
981  falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
982 
983  r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal);
984 
985  // Check whether the operand lies in [-1.0, 0.0).
986  Value greaterThanNegOne = arith::CmpFOp::create(
987  builder, arith::CmpFPredicate::OGE, operand, negOne);
988 
989  Value lessThanZero =
990  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
991 
992  Value betweenNegOneZero =
993  arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero);
994 
995  trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
996  bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
997  arith::NegFOp::create(builder, r));
998 
999  Value finalVal =
1000  arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r);
1001 
1002  rewriter.replaceOp(op, finalVal);
1003  return success();
1004 }
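// The selects above implement the standard identities (summary only):
//   acos(|x|) = pi/2 - asin(|x|)               computed as pi/2 + asin(r), r = -|x|
//   acos(|x|) = 2 * asin(sqrt((1 - |x|) / 2))  the falseVal branch, when r <= -0.5625
//   acos(x)   = pi - acos(-x)                  for x in [-1.0, 0.0)
// The products 9.3282184640716537e-1 * 1.6839188885261840e+0 and
// 1.8656436928143307e+0 * 1.6839188885261840e+0 evaluate to roughly pi/2 and
// pi, so each correction is applied with a single fused multiply-add.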
1005 
1006 //----------------------------------------------------------------------------//
1007 // Erf approximation.
1008 //----------------------------------------------------------------------------//
1009 
1010 // Approximates erf(x) with
1011 // a - P(x)/Q(x)
1012 // where P and Q are polynomials of degree 4.
1013 // Different coefficients are chosen based on the value of x.
1014 // The approximation error is ~2.5e-07.
1015 // Boost's minimax tool that utilizes the Remez method was used to find the
1016 // coefficients.
1017 LogicalResult
1018 ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1019  PatternRewriter &rewriter) const {
1020  Value operand = op.getOperand();
1021  Type elementType = getElementTypeOrSelf(operand);
1022 
1023  if (!(elementType.isF32() || elementType.isF16()))
1024  return rewriter.notifyMatchFailure(op,
1025  "only f32 and f16 type is supported.");
1026  std::optional<VectorShape> shape = vectorShape(operand);
1027 
1028  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1029  auto bcast = [&](Value value) -> Value {
1030  return broadcast(builder, value, shape);
1031  };
1032 
1033  const int intervalsCount = 3;
1034  const int polyDegree = 4;
1035 
1036  Value zero = bcast(floatCst(builder, 0, elementType));
1037  Value one = bcast(floatCst(builder, 1, elementType));
1038  Value pp[intervalsCount][polyDegree + 1];
1039  pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1040  pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
1041  pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
1042  pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
1043  pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
1044  pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1045  pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
1046  pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
1047  pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
1048  pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
1049  pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
1050  pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
1051  pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
1052  pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
1053  pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
1054 
1055  Value qq[intervalsCount][polyDegree + 1];
1056  qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
1057  qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
1058  qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
1059  qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
1060  qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
1061  qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1062  qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
1063  qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
1064  qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
1065  qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
1066  qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1067  qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
1068  qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
1069  qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
1070  qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
1071 
1072  Value offsets[intervalsCount];
1073  offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
1074  offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
1075  offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
1076 
1077  Value bounds[intervalsCount];
1078  bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
1079  bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
1080  bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
1081 
1082  Value isNegativeArg =
1083  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
1084  Value negArg = arith::NegFOp::create(builder, operand);
1085  Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand);
1086 
1087  Value offset = offsets[0];
1088  Value p[polyDegree + 1];
1089  Value q[polyDegree + 1];
1090  for (int i = 0; i <= polyDegree; ++i) {
1091  p[i] = pp[0][i];
1092  q[i] = qq[0][i];
1093  }
1094 
1095  // TODO: maybe use vector stacking to reduce the number of selects.
1096  Value isLessThanBound[intervalsCount];
1097  for (int j = 0; j < intervalsCount - 1; ++j) {
1098  isLessThanBound[j] =
1099  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[j]);
1100  for (int i = 0; i <= polyDegree; ++i) {
1101  p[i] = arith::SelectOp::create(builder, isLessThanBound[j], p[i],
1102  pp[j + 1][i]);
1103  q[i] = arith::SelectOp::create(builder, isLessThanBound[j], q[i],
1104  qq[j + 1][i]);
1105  }
1106  offset = arith::SelectOp::create(builder, isLessThanBound[j], offset,
1107  offsets[j + 1]);
1108  }
1109  isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(
1110  builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1111 
1112  Value pPoly = makePolynomialCalculation(builder, p, x);
1113  Value qPoly = makePolynomialCalculation(builder, q, x);
1114  Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly);
1115  Value formula = arith::AddFOp::create(builder, offset, rationalPoly);
1116  formula = arith::SelectOp::create(
1117  builder, isLessThanBound[intervalsCount - 1], formula, one);
1118 
1119  // erf is odd function: erf(x) = -erf(-x).
1120  Value negFormula = arith::NegFOp::create(builder, formula);
1121  Value res =
1122  arith::SelectOp::create(builder, isNegativeArg, negFormula, formula);
1123 
1124  rewriter.replaceOp(op, res);
1125 
1126  return success();
1127 }
1128 
1129 // Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a degree-9
1130 // polynomial. This approximation is based on the following stackoverflow post:
1131 // https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
1132 // The stackoverflow post is in turn based on:
1133 // M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
1134 // (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
1135 // No. 153, January 1981, pp. 249-253.
1136 //
1137 // Maximum error: 2.65 ulps
1138 LogicalResult
1139 ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
1140  PatternRewriter &rewriter) const {
1141  Value x = op.getOperand();
1142  Type et = getElementTypeOrSelf(x);
1143 
1144  if (!et.isF32())
1145  return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
1146  std::optional<VectorShape> shape = vectorShape(x);
1147 
1148  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1149  auto bcast = [&](Value value) -> Value {
1150  return broadcast(builder, value, shape);
1151  };
1152 
1153  Value trueValue = bcast(boolCst(builder, true));
1154  Value zero = bcast(floatCst(builder, 0.0f, et));
1155  Value one = bcast(floatCst(builder, 1.0f, et));
1156  Value onehalf = bcast(floatCst(builder, 0.5f, et));
1157  Value neg4 = bcast(floatCst(builder, -4.0f, et));
1158  Value neg2 = bcast(floatCst(builder, -2.0f, et));
1159  Value pos2 = bcast(floatCst(builder, 2.0f, et));
1160  Value posInf = bcast(floatCst(builder, INFINITY, et));
1161  Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
1162 
1163  Value a = math::AbsFOp::create(builder, x);
1164  Value p = arith::AddFOp::create(builder, a, pos2);
1165  Value r = arith::DivFOp::create(builder, one, p);
1166  Value q = math::FmaOp::create(builder, neg4, r, one);
1167  Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one),
1168  neg2, a);
1169  Value e =
1170  math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t);
1171  q = math::FmaOp::create(builder, r, e, q);
1172 
1173  p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
1174  Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
1175  p = math::FmaOp::create(builder, p, q, c1);
1176  Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
1177  p = math::FmaOp::create(builder, p, q, c2);
1178  Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
1179  p = math::FmaOp::create(builder, p, q, c3);
1180  Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
1181  p = math::FmaOp::create(builder, p, q, c4);
1182  Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
1183  p = math::FmaOp::create(builder, p, q, c5);
1184  Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
1185  p = math::FmaOp::create(builder, p, q, c6);
1186  Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
1187  p = math::FmaOp::create(builder, p, q, c7);
1188  Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
1189  p = math::FmaOp::create(builder, p, q, c8);
1190  Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
1191  p = math::FmaOp::create(builder, p, q, c9);
1192 
1193  Value d = math::FmaOp::create(builder, pos2, a, one);
1194  r = arith::DivFOp::create(builder, one, d);
1195  q = math::FmaOp::create(builder, p, r, r);
1196  Value negfa = arith::NegFOp::create(builder, a);
1197  Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf);
1198  Value psubq = arith::SubFOp::create(builder, p, q);
1199  e = math::FmaOp::create(builder, fmaqah, pos2, psubq);
1200  r = math::FmaOp::create(builder, e, r, q);
1201 
1202  Value s = arith::MulFOp::create(builder, a, a);
1203  e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s));
1204 
1205  t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s);
1206  r = math::FmaOp::create(
1207  builder, r, e,
1208  arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t));
1209 
1210  Value isNotLessThanInf = arith::XOrIOp::create(
1211  builder,
1212  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf),
1213  trueValue);
1214  r = arith::SelectOp::create(builder, isNotLessThanInf,
1215  arith::AddFOp::create(builder, x, x), r);
1216  Value isGreaterThanClamp =
1217  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal);
1218  r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r);
1219 
1220  Value isNegative =
1221  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero);
1222  r = arith::SelectOp::create(builder, isNegative,
1223  arith::SubFOp::create(builder, pos2, r), r);
1224 
1225  rewriter.replaceOp(op, r);
1226  return success();
1227 }
1228 //----------------------------------------------------------------------------//
1229 // Exp approximation.
1230 //----------------------------------------------------------------------------//
1231 
1232 namespace {
1233 
1234 Value clampWithNormals(ImplicitLocOpBuilder &builder,
1235  const std::optional<VectorShape> shape, Value value,
1236  float lowerBound, float upperBound) {
1237  assert(!std::isnan(lowerBound));
1238  assert(!std::isnan(upperBound));
1239 
1240  auto bcast = [&](Value value) -> Value {
1241  return broadcast(builder, value, shape);
1242  };
1243 
1244  auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1245  return arith::SelectOp::create(
1246  builder, arith::CmpFOp::create(builder, pred, value, bound), value,
1247  bound);
1248  };
1249 
1250  // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1251  // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1252  // arith::{Max,Min}FOp.
1253  value = selectCmp(arith::CmpFPredicate::UGE, value,
1254  bcast(f32Cst(builder, lowerBound)));
1255  value = selectCmp(arith::CmpFPredicate::ULE, value,
1256  bcast(f32Cst(builder, upperBound)));
1257  return value;
1258 }
1259 
1260 struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
1261 public:
1262  using OpRewritePattern::OpRewritePattern;
1263 
1264  LogicalResult matchAndRewrite(math::ExpOp op,
1265  PatternRewriter &rewriter) const final;
1266 };
1267 
1268 LogicalResult
1269 ExpApproximation::matchAndRewrite(math::ExpOp op,
1270  PatternRewriter &rewriter) const {
1271  auto shape = vectorShape(op.getOperand().getType());
1272  auto elementTy = getElementTypeOrSelf(op.getType());
1273  if (!elementTy.isF32())
1274  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1275 
1276  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1277 
1278  auto add = [&](Value a, Value b) -> Value {
1279  return arith::AddFOp::create(builder, a, b);
1280  };
1281  auto bcast = [&](Value value) -> Value {
1282  return broadcast(builder, value, shape);
1283  };
1284  auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
1285  auto fmla = [&](Value a, Value b, Value c) {
1286  return math::FmaOp::create(builder, a, b, c);
1287  };
1288  auto mul = [&](Value a, Value b) -> Value {
1289  return arith::MulFOp::create(builder, a, b);
1290  };
1291 
1292  // Polynomial approximation from Cephes.
1293  //
1294  // To compute e^x, we re-express it as
1295  //
1296  // e^x = e^(a + b)
1297  // = e^(a + n log(2))
1298  // = e^a * 2^n.
1299  //
1300  // We choose n = round(x / log(2)), restricting the value of `a` to
1301  // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
1302  // relative error between our approximation and the true value of e^a is less
1303  // than 2^-22.5 for all values of `a` within this range.
1304 
1305  // Restrict input to a small range, including some values that evaluate to
1306  // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
1307  // log(F32_EPSILON). We do so because this routine always flushes denormal
1308  // floating points to 0. Therefore, we only need to worry about exponentiating
1309  // up to the smallest representable non-denormal floating point, which is
1310  // 2^-126.
1311 
1312  // Constants.
1313  Value cstHalf = bcast(f32Cst(builder, 0.5f));
1314  Value cstOne = bcast(f32Cst(builder, 1.0f));
1315 
1316  // 1/log(2)
1317  Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));
1318 
1319  Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
1320  Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
1321  Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
1322  Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
1323  Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
1324  Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
1325  Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
1326  Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));
1327 
1328  // Our computations below aren't particularly sensitive to the exact choices
1329  // here, so we choose values a bit larger/smaller than
1330  //
1331  // log(F32_MAX) = 88.723...
1332  // log(2^-126) = -87.337...
1333  Value x = op.getOperand();
1334  x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1335  Value n = floor(fmla(x, cstLog2ef, cstHalf));
1336 
1337  // When we eventually do the multiplication in e^a * 2^n, we need to handle
1338  // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1339  // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
1340  // smallest fp32 exponent.
1341  //
1342  // A straightforward solution would be to detect n out of range and split it
1343  // up, doing
1344  //
1345  // e^a * 2^n = e^a * 2^(n1 + n2)
1346  // = (2^n1 * e^a) * 2^n2.
1347  //
1348  // But it turns out this approach is quite slow, probably because it
1349  // manipulates subnormal values.
1350  //
1351  // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1352  // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1353  // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1354  // this value of `a` is outside our previously specified range, e^a will still
1355  // only have a relative error of approximately 2^-16 at worst. In practice
1356  // this seems to work well enough; it passes our exhaustive tests, breaking
1357  // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1358  // we should return inf).
1359  //
1360  // In the case where n' = -127, the original input value of x is so small that
1361  // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1362  // normal floating point, and since we flush denormals, we simply return 0. We
1363  // do this in a branchless way by observing that our code for constructing 2^n
1364  // produces 0 if n = -127.
1365  //
1366  // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1367  //
1368  // n' = -127 implies n <= -127
1369  // implies round(x / log(2)) <= -127
1370  // implies x/log(2) < -126.5
1371  // implies x < -126.5 * log(2)
1372  // implies e^x < e^(-126.5 * log(2))
1373  // implies e^x < 2^-126.5 < 2^-126
1374  //
1375  // This proves that n' = -127 implies e^x < 2^-126.
1376  n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1377 
1378  // Computes x = x - n' * log(2), the value for `a`
1379  x = fmla(cstExpC1, n, x);
1380  x = fmla(cstExpC2, n, x);
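  // The two steps above use a Cody-and-Waite style split of the constant:
  // cstExpC1 + cstExpC2 == -log(2) to within f32 precision, with cstExpC1
  // carrying only a few mantissa bits so that n' * cstExpC1 is exact. This
  // loses less precision than a single fma with log(2) rounded to f32.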
1381 
1382  // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
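  // Written out, the fma chain below is the Horner-form evaluation
  //   z = 1 + a + a^2 * (P5 + a*(P4 + a*(P3 + a*(P2 + a*(P1 + a*P0)))))
  // with the reduced argument a held in `x`.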
1383  Value z = fmla(x, cstExpP0, cstExpP1);
1384  z = fmla(z, x, cstExpP2);
1385  z = fmla(z, x, cstExpP3);
1386  z = fmla(z, x, cstExpP4);
1387  z = fmla(z, x, cstExpP5);
1388  z = fmla(z, mul(x, x), x);
1389  z = add(cstOne, z);
1390 
1391  // Convert n' to an i32. This is safe because we clamped it above.
1392  auto i32Vec = broadcast(builder.getI32Type(), shape);
1393  Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n);
1394 
1395  // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1396  Value pow2 = exp2I32(builder, nI32);
1397 
1398  // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1399  Value ret = mul(z, pow2);
1400 
1401  rewriter.replaceOp(op, ret);
1402  return mlir::success();
1403 }
1404 
1405 } // namespace
1406 
1407 //----------------------------------------------------------------------------//
1408 // ExpM1 approximation.
1409 //----------------------------------------------------------------------------//
1410 
1411 namespace {
1412 
1413 struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1414 public:
1415  using OpRewritePattern::OpRewritePattern;
1416 
1417  LogicalResult matchAndRewrite(math::ExpM1Op op,
1418  PatternRewriter &rewriter) const final;
1419 };
1420 } // namespace
1421 
1422 LogicalResult
1423 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1424  PatternRewriter &rewriter) const {
1425  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1426  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1427 
1428  std::optional<VectorShape> shape = vectorShape(op.getOperand());
1429 
1430  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1431  auto bcast = [&](Value value) -> Value {
1432  return broadcast(builder, value, shape);
1433  };
1434 
1435  // expm1(x) = exp(x) - 1 = u - 1.
1436  // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1437  // and when the input is ~= -inf, i.e. u - 1 ~= -1.
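  // The correction below relies on the standard expm1 identity (often
  // attributed to Kahan): expm1(x) = (u - 1) * x / log(u) with u = exp(x).
  // When u is close to 1, the rounding error in u - 1 is largely cancelled
  // by the matching error in log(u), preserving relative accuracy.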
1438  Value cstOne = bcast(f32Cst(builder, 1.0f));
1439  Value cstNegOne = bcast(f32Cst(builder, -1.0f));
1440  Value x = op.getOperand();
1441  Value u = math::ExpOp::create(builder, x);
1442  Value uEqOneOrNaN =
1443  arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne);
1444  Value uMinusOne = arith::SubFOp::create(builder, u, cstOne);
1445  Value uMinusOneEqNegOne = arith::CmpFOp::create(
1446  builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1447  // logU = log(u) ~= x
1448  Value logU = math::LogOp::create(builder, u);
1449 
1450  // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1451  Value isInf =
1452  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u);
1453 
1454  // expm1(x) ~= (u - 1) * (x / log(u)), where log(u) ~= x.
1455  Value expm1 = arith::MulFOp::create(builder, uMinusOne,
1456  arith::DivFOp::create(builder, x, logU));
1457  expm1 = arith::SelectOp::create(builder, isInf, u, expm1);
1458  Value approximation = arith::SelectOp::create(
1459  builder, uEqOneOrNaN, x,
1460  arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1));
1461  rewriter.replaceOp(op, approximation);
1462  return success();
1463 }
1464 
1465 //----------------------------------------------------------------------------//
1466 // Sin and Cos approximation.
1467 //----------------------------------------------------------------------------//
1468 
1469 namespace {
1470 
1471 template <bool isSine, typename OpTy>
1472 struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1473 public:
1474  using OpRewritePattern<OpTy>::OpRewritePattern;
1475 
1476  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1477 };
1478 } // namespace
1479 
1480 #define TWO_OVER_PI \
1481  0.6366197723675813430755350534900574481378385829618257949906693762L
1482 #define PI_OVER_2 \
1483  1.5707963267948966192313216916397514420985846996875529104874722961L
1484 
1485 // Approximates sin(x) or cos(x) with polynomial approximations of sin and cos
1486 // on the reduced range [0, pi/2]. Given the reduced argument y in that range,
1487 // sin(x) (or cos(x)) is computed as sin(y), cos(y), -sin(y), or -cos(y).
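// Concretely, with k = floor(x * 2/pi) and y = x - k * pi/2 (so y lies in
// [0, pi/2)), the quadrant k mod 4 picks the polynomial and the sign:
//   k mod 4:   0        1        2        3
//   sin(x):    sin(y)   cos(y)  -sin(y)  -cos(y)
//   cos(x):    cos(y)  -sin(y)  -cos(y)   sin(y)
// This is what the `sinuseCos` and `negativeRange` selects below implement.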
1488 template <bool isSine, typename OpTy>
1489 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1490  OpTy op, PatternRewriter &rewriter) const {
1491  static_assert(
1492  llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1493  "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1494 
1495  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1496  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1497 
1498  std::optional<VectorShape> shape = vectorShape(op.getOperand());
1499 
1500  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1501  auto bcast = [&](Value value) -> Value {
1502  return broadcast(builder, value, shape);
1503  };
1504  auto mul = [&](Value a, Value b) -> Value {
1505  return arith::MulFOp::create(builder, a, b);
1506  };
1507  auto sub = [&](Value a, Value b) -> Value {
1508  return arith::SubFOp::create(builder, a, b);
1509  };
1510  auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
1511 
1512  auto i32Vec = broadcast(builder.getI32Type(), shape);
1513  auto fPToSignedInteger = [&](Value a) -> Value {
1514  return arith::FPToSIOp::create(builder, i32Vec, a);
1515  };
1516 
1517  auto modulo4 = [&](Value a) -> Value {
1518  return arith::AndIOp::create(builder, a, bcast(i32Cst(builder, 3)));
1519  };
1520 
1521  auto isEqualTo = [&](Value a, Value b) -> Value {
1522  return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b);
1523  };
1524 
1525  auto isGreaterThan = [&](Value a, Value b) -> Value {
1526  return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b);
1527  };
1528 
1529  auto select = [&](Value cond, Value t, Value f) -> Value {
1530  return arith::SelectOp::create(builder, cond, t, f);
1531  };
1532 
1533  auto fmla = [&](Value a, Value b, Value c) {
1534  return math::FmaOp::create(builder, a, b, c);
1535  };
1536 
1537  auto bitwiseOr = [&](Value a, Value b) {
1538  return arith::OrIOp::create(builder, a, b);
1539  };
1540 
1541  Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
1542  Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
1543 
1544  Value x = op.getOperand();
1545 
1546  Value k = floor(mul(x, twoOverPi));
1547 
1548  Value y = sub(x, mul(k, piOverTwo));
1549 
1550  Value cstOne = bcast(f32Cst(builder, 1.0));
1551  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
1552 
1553  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
1554  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
1555  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1556  Value cstSC8 =
1557  bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1558  Value cstSC10 =
1559  bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1560 
1561  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
1562  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
1563  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
1564  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1565  Value cstCC10 =
1566  bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1567 
1568  Value kMod4 = modulo4(fPToSignedInteger(k));
1569 
1570  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
1571  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
1572  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
1573  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
1574 
1575  Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1576  Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
1577  : bitwiseOr(kR1, kR2);
1578 
1579  Value y2 = mul(y, y);
1580 
1581  Value base = select(sinuseCos, cstOne, y);
1582  Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1583  Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1584  Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1585  Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1586  Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1587 
1588  Value v1 = fmla(y2, cstC10, cstC8);
1589  Value v2 = fmla(y2, v1, cstC6);
1590  Value v3 = fmla(y2, v2, cstC4);
1591  Value v4 = fmla(y2, v3, cstC2);
1592  Value v5 = fmla(y2, v4, cstOne);
1593  Value v6 = mul(base, v5);
1594 
1595  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1596 
1597  rewriter.replaceOp(op, approximation);
1598 
1599  return success();
1600 }
1601 
1602 //----------------------------------------------------------------------------//
1603 // Cbrt approximation.
1604 //----------------------------------------------------------------------------//
1605 
1606 namespace {
1607 struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1608  using OpRewritePattern::OpRewritePattern;
1609 
1610  LogicalResult matchAndRewrite(math::CbrtOp op,
1611  PatternRewriter &rewriter) const final;
1612 };
1613 } // namespace
1614 
1615 // Estimation of cube-root using an algorithm defined in
1616 // Hacker's Delight 2nd Edition.
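// The initial guess operates on the bit pattern of |x|: the shift/add
// sequence ix/4 + ix/16, then + ix/16 of that, then + ix/256 of that, sums
// to roughly ix * 0.3333, i.e. it approximately divides the biased exponent
// (and mantissa) by 3; the magic constant 0x2a5137a0 then adds back roughly
// two thirds of the exponent bias. Two Newton iterations for x^3 = x0,
//   x <- (2*x + x0/(x*x)) / 3,
// refine that estimate.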
1617 LogicalResult
1618 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1619  PatternRewriter &rewriter) const {
1620  auto operand = op.getOperand();
1621  if (!getElementTypeOrSelf(operand).isF32())
1622  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1623 
1624  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1625  std::optional<VectorShape> shape = vectorShape(operand);
1626 
1627  Type floatTy = getElementTypeOrSelf(operand.getType());
1628  Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1629 
1630  // Convert to vector types if necessary.
1631  floatTy = broadcast(floatTy, shape);
1632  intTy = broadcast(intTy, shape);
1633 
1634  auto bconst = [&](TypedAttr attr) -> Value {
1635  Value value = arith::ConstantOp::create(b, attr);
1636  return broadcast(b, value, shape);
1637  };
1638 
1639  // Declare the initial values:
1640  Value intTwo = bconst(b.getI32IntegerAttr(2));
1641  Value intFour = bconst(b.getI32IntegerAttr(4));
1642  Value intEight = bconst(b.getI32IntegerAttr(8));
1643  Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1644  Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1645  Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1646  Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1647 
1648  // Compute an approximation of one third:
1649  // union {int ix; float x;};
1650  // x = x0;
1651  // ix = ix/4 + ix/16;
1652  Value absValue = math::AbsFOp::create(b, operand);
1653  Value intValue = arith::BitcastOp::create(b, intTy, absValue);
1654  Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo);
1655  Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1656  intValue = arith::AddIOp::create(b, divideBy4, divideBy16);
1657 
1658  // ix = ix + ix/16;
1659  divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1660  intValue = arith::AddIOp::create(b, intValue, divideBy16);
1661 
1662  // ix = ix + ix/256;
1663  Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight);
1664  intValue = arith::AddIOp::create(b, intValue, divideBy256);
1665 
1666  // ix = 0x2a5137a0 + ix;
1667  intValue = arith::AddIOp::create(b, intValue, intMagic);
1668 
1669  // Perform one Newton step:
1670  // x = 0.33333333f*(2.0f*x + x0/(x*x));
1671  Value floatValue = arith::BitcastOp::create(b, floatTy, intValue);
1672  Value squared = arith::MulFOp::create(b, floatValue, floatValue);
1673  Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1674  Value divSquared = arith::DivFOp::create(b, absValue, squared);
1675  floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1676  floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1677 
1678  // x = 0.33333333f*(2.0f*x + x0/(x*x));
1679  squared = arith::MulFOp::create(b, floatValue, floatValue);
1680  mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1681  divSquared = arith::DivFOp::create(b, absValue, squared);
1682  floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1683  floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1684 
1685  // Check for zero and restore sign.
1686  Value isZero =
1687  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero);
1688  floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue);
1689  floatValue = math::CopySignOp::create(b, floatValue, operand);
1690 
1691  rewriter.replaceOp(op, floatValue);
1692  return success();
1693 }
1694 
1695 //----------------------------------------------------------------------------//
1696 // Rsqrt approximation.
1697 //----------------------------------------------------------------------------//
1698 
1699 namespace {
1700 struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1701  using OpRewritePattern::OpRewritePattern;
1702 
1703  LogicalResult matchAndRewrite(math::RsqrtOp op,
1704  PatternRewriter &rewriter) const final;
1705 };
1706 } // namespace
1707 
1708 LogicalResult
1709 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1710  PatternRewriter &rewriter) const {
1711  if (!getElementTypeOrSelf(op.getOperand()).isF32())
1712  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1713 
1714  std::optional<VectorShape> shape = vectorShape(op.getOperand());
1715 
1716  // Only support already-vectorized rsqrt ops whose innermost dimension is a
1716  // multiple of 8.
1717  if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1718  return rewriter.notifyMatchFailure(op, "unsupported operand type");
1719 
1720  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1721  auto bcast = [&](Value value) -> Value {
1722  return broadcast(builder, value, shape);
1723  };
1724 
1725  Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1726  Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1727  Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1728  Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1729 
1730  Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf);
1731 
1732  // Select only the inverse sqrt of positive normals (denormals are
1733  // flushed to zero).
1734  Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT,
1735  op.getOperand(), cstMinNormPos);
1736  Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
1737  op.getOperand(), cstPosInf);
1738  Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask);
1739 
1740  // Compute an approximate result.
1741  Value yApprox = handleMultidimensionalVectors(
1742  builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1743  return x86vector::RsqrtOp::create(builder, operands);
1744  });
1745 
1746  // Do a single step of Newton-Raphson iteration to improve the approximation.
1747  // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1748  // It is essential to evaluate the inner term like this because forming
1749  // y_n^2 may over- or underflow.
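  // For reference, the update is Newton's method applied to f(y) = 1/y^2 - x,
  // whose positive root is 1/sqrt(x):
  //   y_{n+1} = y_n - f(y_n)/f'(y_n) = y_n * (1.5 - 0.5 * x * y_n * y_n),
  // computed below as y_n * (1.5 + ((-0.5 * x) * y_n) * y_n).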
1750  Value inner = arith::MulFOp::create(builder, negHalf, yApprox);
1751  Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive);
1752  Value yNewton = arith::MulFOp::create(builder, yApprox, fma);
1753 
1754  // Select the result of the Newton-Raphson step for positive normal arguments.
1755  // For other arguments, choose the output of the intrinsic. This will
1756  // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1757  // x is zero or a positive denormalized float (equivalent to flushing positive
1758  // denormalized inputs to zero).
1759  Value res =
1760  arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton);
1761  rewriter.replaceOp(op, res);
1762 
1763  return success();
1764 }
1765 
1766 //----------------------------------------------------------------------------//
1767 
1768 void mlir::populatePolynomialApproximateTanhPattern(
1769  RewritePatternSet &patterns) {
1770  patterns.add<TanhApproximation>(patterns.getContext());
1771 }
1772 
1773 void mlir::populatePolynomialApproximateErfPattern(
1774  RewritePatternSet &patterns) {
1775  patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1776 }
1777 
1778 void mlir::populatePolynomialApproximateErfcPattern(
1779  RewritePatternSet &patterns) {
1780  patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
1781 }
1782 
1783 template <typename OpType>
1784 static void
1785 populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1786  llvm::function_ref<bool(StringRef)> predicate,
1787  PatternBenefit benefit) {
1788  if (predicate(OpType::getOperationName())) {
1789  patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
1790  }
1791 }
1792 
1793 void mlir::populateMathF32ExpansionPatterns(
1794  RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1795  PatternBenefit benefit) {
1796  populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
1797  populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
1798  populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
1799  populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
1800  populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
1801  populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
1802  populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
1803  populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
1804  populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
1805  populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
1806  populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
1807  populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
1808  populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
1809  populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
1810  populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
1811  populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
1812  populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
1813  populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
1814  populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
1815  populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
1816  populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
1817  populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
1818  populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
1819  populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
1820  populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
1821  populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
1822 }
1823 
1824 template <typename OpType, typename PatternType>
1825 static void populateMathPolynomialApproximationPattern(
1826  RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1827  PatternBenefit benefit) {
1828  if (predicate(OpType::getOperationName())) {
1829  patterns.add<PatternType>(patterns.getContext(), benefit);
1830  }
1831 }
1832 
1833 void mlir::populateMathPolynomialApproximationPatterns(
1834  RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1835  PatternBenefit benefit) {
1836  populateMathPolynomialApproximationPattern<AcosOp,
1837  AcosPolynomialApproximation>(
1838  patterns, predicate, benefit);
1839  populateMathPolynomialApproximationPattern<AsinOp,
1840  AsinPolynomialApproximation>(
1841  patterns, predicate, benefit);
1842  populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1843  patterns, predicate, benefit);
1844  populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1845  patterns, predicate, benefit);
1846  populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1847  patterns, predicate, benefit);
1848  populateMathPolynomialApproximationPattern<
1849  CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
1850  benefit);
1851  populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1852  patterns, predicate, benefit);
1853  populateMathPolynomialApproximationPattern<ErfcOp,
1854  ErfcPolynomialApproximation>(
1855  patterns, predicate, benefit);
1856  populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1857  patterns, predicate, benefit);
1858  populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1859  patterns, predicate, benefit);
1860  populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1861  patterns, predicate, benefit);
1862  populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1863  patterns, predicate, benefit);
1864  populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1865  patterns, predicate, benefit);
1866  populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1867  patterns, predicate, benefit);
1868  populateMathPolynomialApproximationPattern<
1869  SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
1870  benefit);
1871  populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1872  patterns, predicate, benefit);
1873 }
1874 
1875 void mlir::populateMathPolynomialApproximationPatterns(
1876  RewritePatternSet &patterns,
1877  const MathPolynomialApproximationOptions &options) {
1878  mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) -> bool {
1879  return llvm::is_contained(
1880  {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1881  math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1882  math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1883  math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1884  math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1885  math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1886  math::CosOp::getOperationName()},
1887  name);
1888  });
1889 
1890  mlir::populateMathPolynomialApproximationPatterns(
1891  patterns, [](StringRef name) -> bool {
1892  return llvm::is_contained(
1893  {math::AtanOp::getOperationName(),
1894  math::Atan2Op::getOperationName(),
1895  math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1896  math::Log2Op::getOperationName(),
1897  math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1898  math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1899  math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1900  math::ExpM1Op::getOperationName(),
1901  math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1902  math::CosOp::getOperationName()},
1903  name);
1904  });
1905 
1906  if (options.enableAvx2) {
1907  auto predicateRsqrt = [](StringRef name) {
1908  return name == math::RsqrtOp::getOperationName();
1909  };
1910  mlir::populateMathPolynomialApproximationPatterns(patterns,
1911  predicateRsqrt);
1912  }
1913 }
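
// A minimal usage sketch (not part of this file). It assumes the options
// struct exposed alongside these populate functions and the greedy pattern
// rewrite driver (`applyPatternsGreedily`; older MLIR spells it
// `applyPatternsAndFoldGreedily`):
//
//   RewritePatternSet patterns(ctx);
//   MathPolynomialApproximationOptions approxOptions;
//   approxOptions.enableAvx2 = false;
//   populateMathPolynomialApproximationPatterns(patterns, approxOptions);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();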