PolynomialApproximation.cpp
1//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements expansion of math operations to fast approximations
10// that do not rely on any of the library functions.
11//
12//===----------------------------------------------------------------------===//
13
14#include <climits>
15#include <cmath>
16#include <cstddef>
17
26#include "mlir/IR/Builders.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/Support/MathExtras.h"
35
36using namespace mlir;
37using namespace mlir::math;
38using namespace mlir::vector;
39
40// Helper to encapsulate a vector's shape (including scalable dims).
41struct VectorShape {
42 ArrayRef<int64_t> sizes;
43 ArrayRef<bool> scalableFlags;
44};
45
46// Returns the vector shape if the type is a vector, otherwise returns nullopt.
47static std::optional<VectorShape> vectorShape(Type type) {
48 if (auto vectorType = dyn_cast<VectorType>(type)) {
49 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
50 }
51 return std::nullopt;
52}
53
54static std::optional<VectorShape> vectorShape(Value value) {
55 return vectorShape(value.getType());
56}
57
58//----------------------------------------------------------------------------//
59// Broadcast scalar types and values into vector types and values.
60//----------------------------------------------------------------------------//
61
62// Broadcasts scalar type into vector type (iff shape is non-scalar).
63static Type broadcast(Type type, std::optional<VectorShape> shape) {
64 assert(!isa<VectorType>(type) && "must be scalar type");
65 return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
66 : type;
67}
68
69// Broadcasts scalar value into vector (iff shape is non-scalar).
70static Value broadcast(ImplicitLocOpBuilder &builder, Value value,
71 std::optional<VectorShape> shape) {
72 assert(!isa<VectorType>(value.getType()) && "must be scalar value");
73 auto type = broadcast(value.getType(), shape);
74 return shape ? BroadcastOp::create(builder, type, value) : value;
75}
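// For example (illustrative): broadcasting an f32 scalar with shape sizes =
// [4, 8] yields a vector<4x8xf32> built with vector.broadcast, while a
// std::nullopt shape returns the scalar type or value unchanged.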
76
77//----------------------------------------------------------------------------//
78// Helper function to handle n-D vectors with 1-D operations.
79//----------------------------------------------------------------------------//
80
81// Expands and unrolls n-D vector operands into multiple fixed size 1-D vectors
82// and calls the compute function with 1-D vector operands. Stitches back all
83// results into the original n-D vector result.
84//
85// Examples: vectorWidth = 8
86// - vector<4x8xf32> unrolled 4 times
87// - vector<16xf32> expanded to vector<2x8xf32> and unrolled 2 times
88// - vector<4x16xf32> expanded to vector<4x2x8xf32> and unrolled 4*2 times
89//
90// Some math approximations rely on ISA-specific operations that only accept
91// fixed size 1-D vectors (e.g. AVX expects vectors of width 8).
92//
93// It is the caller's responsibility to verify that the inner dimension is
94// divisible by the vectorWidth, and that all operands have the same vector
95// shape.
96static Value
97handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
98 ValueRange operands, int64_t vectorWidth,
99 llvm::function_ref<Value(ValueRange)> compute) {
100 assert(!operands.empty() && "operands must not be empty");
101 assert(vectorWidth > 0 && "vector width must be larger than 0");
102
103 VectorType inputType = cast<VectorType>(operands[0].getType());
104 ArrayRef<int64_t> inputShape = inputType.getShape();
105
106 // If input shape matches target vector width, we can just call the
107 // user-provided compute function with the operands.
108 if (inputShape == llvm::ArrayRef(vectorWidth))
109 return compute(operands);
110
111 // Check if the inner dimension has to be expanded, or we can directly iterate
112 // over the outer dimensions of the vector.
113 int64_t innerDim = inputShape.back();
114 int64_t expansionDim = innerDim / vectorWidth;
115 assert((innerDim % vectorWidth == 0) && "invalid inner dimension size");
116
117 // Maybe expand operands to the higher rank vector shape that we'll use to
118 // iterate over and extract one dimensional vectors.
119 SmallVector<int64_t> expandedShape(inputShape);
120 SmallVector<Value> expandedOperands(operands);
121
122 if (expansionDim > 1) {
123 // Expand shape from [..., innerDim] to [..., expansionDim, vectorWidth].
124 expandedShape.insert(expandedShape.end() - 1, expansionDim);
125 expandedShape.back() = vectorWidth;
126
127 for (unsigned i = 0; i < operands.size(); ++i) {
128 auto operand = operands[i];
129 auto eltType = cast<VectorType>(operand.getType()).getElementType();
130 auto expandedType = VectorType::get(expandedShape, eltType);
131 expandedOperands[i] =
132 vector::ShapeCastOp::create(builder, expandedType, operand);
133 }
134 }
135
136 // Iterate over all outer dimensions of the compute shape vector type.
137 auto iterationDims = ArrayRef<int64_t>(expandedShape).drop_back();
138 int64_t maxIndex = computeMaxLinearIndex(iterationDims);
139 auto strides = computeStrides(iterationDims);
140
141 // Compute results for each one dimensional vector.
142 SmallVector<Value> results(maxIndex);
143
144 for (int64_t i = 0; i < maxIndex; ++i) {
145 auto offsets = delinearize(i, strides);
146
147 SmallVector<Value> extracted(expandedOperands.size());
148 for (const auto &tuple : llvm::enumerate(expandedOperands))
149 extracted[tuple.index()] =
150 vector::ExtractOp::create(builder, tuple.value(), offsets);
151
152 results[i] = compute(extracted);
153 }
154
155 // Stitch results together into one large vector.
156 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
157 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
158 Value result = arith::ConstantOp::create(
159 builder, resultExpandedType, builder.getZeroAttr(resultExpandedType));
160
161 for (int64_t i = 0; i < maxIndex; ++i)
162 result = vector::InsertOp::create(builder, results[i], result,
163 delinearize(i, strides));
164
165 // Reshape back to the original vector shape.
166 return vector::ShapeCastOp::create(
167 builder, VectorType::get(inputShape, resultEltType), result);
168}
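// Worked example of the unrolling above (illustrative): for operands of type
// vector<4x16xf32> and vectorWidth = 8, innerDim = 16 and expansionDim = 2, so
// expandedShape becomes [4, 2, 8]. The iteration dims [4, 2] give maxIndex = 8
// and strides = [2, 1]; linear index 5, for instance, delinearizes to offsets
// [2, 1], where a vector<8xf32> slice is extracted, passed to `compute`, and
// inserted back before the final shape_cast to vector<4x16xf32>.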
169
170//----------------------------------------------------------------------------//
171// Helper functions to create constants.
172//----------------------------------------------------------------------------//
173
174static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
175 return arith::ConstantOp::create(builder, builder.getBoolAttr(value));
176}
177
178static Value floatCst(ImplicitLocOpBuilder &builder, float value,
179 Type elementType) {
180 assert((elementType.isF16() || elementType.isF32()) &&
181 "x must be f16 or f32 type.");
182 return arith::ConstantOp::create(builder,
183 builder.getFloatAttr(elementType, value));
184}
185
186static Value f32Cst(ImplicitLocOpBuilder &builder, double value) {
187 return arith::ConstantOp::create(builder, builder.getF32FloatAttr(value));
188}
189
190static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value) {
191 return arith::ConstantOp::create(builder, builder.getI32IntegerAttr(value));
192}
193
194static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits) {
195 Value i32Value = i32Cst(builder, static_cast<int32_t>(bits));
196 return arith::BitcastOp::create(builder, builder.getF32Type(), i32Value);
197}
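// Note on the bit patterns used below (informative): 0x7f800000 is +inf,
// 0xff800000 is -inf, 0x7fc00000 is a quiet NaN, and 0x00800000 is the
// smallest positive normal f32.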
198
199//----------------------------------------------------------------------------//
200// Helper functions to build math functions approximations.
201//----------------------------------------------------------------------------//
202
203// Return the minimum of the two values or NaN if value is NaN
204static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
205 return arith::SelectOp::create(
206 builder,
207 arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound),
208 value, bound);
209}
210
211// Return the maximum of the two values or NaN if value is NaN
212static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
213 return arith::SelectOp::create(
214 builder,
215 arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound),
216 value, bound);
217}
218
219// Return the clamped value or NaN if value is NaN
220static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
221 Value upperBound) {
222 return max(builder, min(builder, value, upperBound), lowerBound);
223}
224
225// Decomposes given floating point value `arg` into a normalized fraction and
226// an integral power of two (see std::frexp). Returned values have float type.
227static std::pair<Value, Value> frexp(ImplicitLocOpBuilder &builder, Value arg,
228 bool isPositive = false) {
229 assert(getElementTypeOrSelf(arg).isF32() && "arg must be f32 type");
230 std::optional<VectorShape> shape = vectorShape(arg);
231
232 auto bcast = [&](Value value) -> Value {
233 return broadcast(builder, value, shape);
234 };
235
236 auto i32 = builder.getIntegerType(32);
237 auto i32Vec = broadcast(i32, shape);
238 auto f32Vec = broadcast(builder.getF32Type(), shape);
239
240 Value cst126f = f32Cst(builder, 126.0f);
241 Value cstHalf = f32Cst(builder, 0.5f);
242 Value cstInvMantMask = f32FromBits(builder, ~0x7f800000u);
243
244 // Bitcast to i32 for bitwise operations.
245 Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf);
246 Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask);
247 Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg);
248
249 // Compute normalized fraction.
250 Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half));
252 Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1);
253
254 // Compute exponent.
255 Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg);
256 Value biasedExponentBits = arith::ShRUIOp::create(
257 builder, arith::BitcastOp::create(builder, i32Vec, arg0),
258 bcast(i32Cst(builder, 23)));
259 Value biasedExponent =
260 arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits);
261 Value exponent =
262 arith::SubFOp::create(builder, biasedExponent, bcast(cst126f));
263
264 return {normalizedFraction, exponent};
265}
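// Worked example (illustrative): for arg = 8.0f (bits 0x41000000, biased
// exponent 130), the mantissa bits OR'ed with the bits of 0.5f give a
// normalized fraction of 0.5f, and the exponent is 130 - 126 = 4, matching
// std::frexp(8.0) = 0.5 * 2^4. The constant is 126 rather than 127 because the
// fraction is scaled into [0.5, 1) instead of [1, 2).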
266
267// Computes exp2 for an i32 argument.
268static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg) {
269 assert(getElementTypeOrSelf(arg).isInteger(32) && "arg must be i32 type");
270 std::optional<VectorShape> shape = vectorShape(arg);
271
272 auto bcast = [&](Value value) -> Value {
273 return broadcast(builder, value, shape);
274 };
275
276 auto f32Vec = broadcast(builder.getF32Type(), shape);
277 // The exponent of an f32 value is located at bit 23.
278 auto exponentBitLocation = bcast(i32Cst(builder, 23));
279 // Add the f32 exponent bias (127) so that an exponent of 0 maps to 1.0f.
280 auto bias = bcast(i32Cst(builder, 127));
281
282 Value biasedArg = arith::AddIOp::create(builder, arg, bias);
283 Value exp2ValueInt =
284 arith::ShLIOp::create(builder, biasedArg, exponentBitLocation);
285 Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt);
286
287 return exp2ValueF32;
288}
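// Worked example (illustrative): exp2I32(3) biases the argument to 130 and
// shifts it into the exponent field, producing bits 0x41000000 == 8.0f == 2^3.
// For arg = -127 the biased value is 0 and the result is +0.0f; the exp
// approximation below relies on this to flush tiny results to zero.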
289
290namespace {
291Value makePolynomialCalculation(ImplicitLocOpBuilder &builder,
292 llvm::ArrayRef<Value> coeffs, Value x) {
293 Type elementType = getElementTypeOrSelf(x);
294 assert((elementType.isF32() || elementType.isF16()) &&
295 "x must be f32 or f16 type");
296 std::optional<VectorShape> shape = vectorShape(x);
297
298 if (coeffs.empty())
299 return broadcast(builder, floatCst(builder, 0.0f, elementType), shape);
300
301 if (coeffs.size() == 1)
302 return coeffs[0];
303
304 Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = math::FmaOp::create(builder, x, res, coeffs[i]);
308 }
309 return res;
310}
311} // namespace
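// makePolynomialCalculation evaluates the polynomial in Horner form using
// FmaOps, with coefficients ordered from the constant term upwards. For
// example (illustrative), coeffs = {c0, c1, c2} produces
// fma(x, fma(x, c2, c1), c0) == c0 + c1*x + c2*x^2.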
312
313//----------------------------------------------------------------------------//
314// Helper function/pattern to insert casts for reusing F32 bit expansion.
315//----------------------------------------------------------------------------//
316
317template <typename T>
318LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
319 // Conservatively only allow ops whose operand and result types all match.
320 Type origType = op->getResultTypes().front();
321 for (Type t : llvm::drop_begin(op->getResultTypes()))
322 if (origType != t)
323 return rewriter.notifyMatchFailure(op, "required all types to match");
324 for (Type t : op->getOperandTypes())
325 if (origType != t)
326 return rewriter.notifyMatchFailure(op, "required all types to match");
327
328 // Skip if already F32 or larger than 32 bits.
329 if (getElementTypeOrSelf(origType).isF32() ||
330 getElementTypeOrSelf(origType).getIntOrFloatBitWidth() > 32)
331 return failure();
332
333 // Create F32 equivalent type.
334 Type newType;
335 if (auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.getF32Type());
337 } else if (isa<FloatType>(origType)) {
338 newType = rewriter.getF32Type();
339 } else {
340 return rewriter.notifyMatchFailure(op,
341 "unable to find F32 equivalent type");
342 }
343
344 Location loc = op->getLoc();
345 SmallVector<Value> operands;
346 for (auto operand : op->getOperands())
347 operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand));
348 auto result =
349 T::create(rewriter, loc, TypeRange{newType}, operands, op->getAttrs());
350 rewriter.replaceOpWithNewOp<arith::TruncFOp>(op, origType, result);
351 return success();
352}
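// Illustrative effect of the cast insertion (IR sketch, not taken from the
// source): a single-result f16 op such as `math.atan %arg : f16` becomes
//   %0 = arith.extf %arg : f16 to f32
//   %1 = math.atan %0 : f32
//   %2 = arith.truncf %1 : f32 to f16
// so that the f32 approximation patterns below can apply to %1.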
353
354namespace {
355// Pattern to cast to F32 to reuse F32 expansion as fallback for single-result
356// op.
357// TODO: Consider revising to avoid adding multiple casts for a subgraph that is
358// all in lower precision. Currently this is only fallback support and performs
359// simplistic casting.
360template <typename T>
361struct ReuseF32Expansion : public OpRewritePattern<T> {
362public:
363 using OpRewritePattern<T>::OpRewritePattern;
364 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const final {
365 static_assert(
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
369 }
370};
371} // namespace
372
373//----------------------------------------------------------------------------//
374// AtanOp approximation.
375//----------------------------------------------------------------------------//
376
377namespace {
378struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
379public:
380 using OpRewritePattern::OpRewritePattern;
381
382 LogicalResult matchAndRewrite(math::AtanOp op,
383 PatternRewriter &rewriter) const final;
384};
385} // namespace
386
387LogicalResult
388AtanApproximation::matchAndRewrite(math::AtanOp op,
389 PatternRewriter &rewriter) const {
390 auto operand = op.getOperand();
391 if (!getElementTypeOrSelf(operand).isF32())
392 return rewriter.notifyMatchFailure(op, "unsupported operand type");
393
394 std::optional<VectorShape> shape = vectorShape(op.getOperand());
395
396 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
397 Value abs = math::AbsFOp::create(builder, operand);
398
399 auto one = broadcast(builder, f32Cst(builder, 1.0), shape);
400
401 // When 0.66 < x <= 2.41 we do (x-1) / (x+1):
402 auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape);
403 Value cmp2 =
404 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds);
405 Value addone = arith::AddFOp::create(builder, abs, one);
406 Value subone = arith::SubFOp::create(builder, abs, one);
407 Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs);
408 Value xden = arith::SelectOp::create(builder, cmp2, addone, one);
409
410 auto bcast = [&](Value value) -> Value {
411 return broadcast(builder, value, shape);
412 };
413
414 // For |x| <= 0.66 keep x; for |x| > 2.41 (tan(3*pi/8)) switch to 1/x:
415 auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880));
416 Value cmp1 =
417 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8);
418 xnum = arith::SelectOp::create(builder, cmp1, one, xnum);
419 xden = arith::SelectOp::create(builder, cmp1, abs, xden);
420
421 Value x = arith::DivFOp::create(builder, xnum, xden);
422 Value xx = arith::MulFOp::create(builder, x, x);
423
424 // Evaluate the rational approximation of atan over the range
425 // [0.0, 0.66].
426 auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02));
436
437 // Apply the polynomial approximation for the numerator:
438 Value n = p0;
439 n = math::FmaOp::create(builder, xx, n, p1);
440 n = math::FmaOp::create(builder, xx, n, p2);
441 n = math::FmaOp::create(builder, xx, n, p3);
442 n = math::FmaOp::create(builder, xx, n, p4);
443 n = arith::MulFOp::create(builder, n, xx);
444
445 // Apply the polynomial approximation for the denominator:
446 Value d = q0;
447 d = math::FmaOp::create(builder, xx, d, q1);
448 d = math::FmaOp::create(builder, xx, d, q2);
449 d = math::FmaOp::create(builder, xx, d, q3);
450 d = math::FmaOp::create(builder, xx, d, q4);
451
452 // Compute approximation of theta:
453 Value ans0 = arith::DivFOp::create(builder, n, d);
454 ans0 = math::FmaOp::create(builder, ans0, x, x);
455
456 // Correct for the input mapping's angles:
457 Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = arith::AddFOp::create(builder, mpi4, ans0);
459 Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0);
460
461 Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = arith::SubFOp::create(builder, mpi2, ans0);
463 ans = arith::SelectOp::create(builder, cmp1, ans1, ans);
464
465 // Correct for the sign of the input.
466 rewriter.replaceOpWithNewOp<math::CopySignOp>(op, ans, operand);
467 return success();
468}
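// The range reduction above uses the standard identities (informative note):
// for |x| > tan(3*pi/8) ~= 2.414, atan(x) = pi/2 - atan(1/x); for
// 0.66 < |x| <= 2.414, atan(x) = pi/4 + atan((x - 1) / (x + 1)). The rational
// polynomial therefore only needs to be accurate on [0, 0.66], where
// atan(x) ~= x + x * x^2 * P(x^2) / Q(x^2).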
469
470//----------------------------------------------------------------------------//
471// Atan2Op approximation.
472//----------------------------------------------------------------------------//
473
474namespace {
475struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
476public:
477 using OpRewritePattern::OpRewritePattern;
478
479 LogicalResult matchAndRewrite(math::Atan2Op op,
480 PatternRewriter &rewriter) const final;
481};
482} // namespace
483
484LogicalResult
485Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486 PatternRewriter &rewriter) const {
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
489 if (!getElementTypeOrSelf(x).isF32())
490 return rewriter.notifyMatchFailure(op, "unsupported operand type");
491
492 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493 std::optional<VectorShape> shape = vectorShape(op.getResult());
494
495 // Compute atan in the valid range.
496 auto div = arith::DivFOp::create(builder, y, x);
497 auto atan = math::AtanOp::create(builder, div);
498
499 // Determine what the atan would be for a 180 degree rotation.
500 auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
501 auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
502 auto addPi = arith::AddFOp::create(builder, atan, pi);
503 auto subPi = arith::SubFOp::create(builder, atan, pi);
504 auto atanGt =
505 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi);
507
508 // Determine whether to directly use atan or use the 180 degree flip
509 auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero);
510 Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan);
511
512 // Handle x = 0, y > 0
513 Value xZero =
514 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt =
516 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero);
517 Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt);
518 auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
519 result = arith::SelectOp::create(builder, isHalfPi, halfPi, result);
520
521 // Handle x = 0, y < 0
522 Value yLt =
523 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero);
524 Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt);
525 auto negativeHalfPiPi =
526 broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
527 result = arith::SelectOp::create(builder, isNegativeHalfPiPi,
528 negativeHalfPiPi, result);
529
530 // Handle x = 0, y = 0;
531 Value yZero =
532 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero);
533 Value isNan = arith::AndIOp::create(builder, xZero, yZero);
534 Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
535 result = arith::SelectOp::create(builder, isNan, cstNan, result);
536
537 rewriter.replaceOp(op, result);
538 return success();
539}
540
541//----------------------------------------------------------------------------//
542// TanhOp approximation.
543//----------------------------------------------------------------------------//
544
545namespace {
546struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
547public:
548 using OpRewritePattern::OpRewritePattern;
549
550 LogicalResult matchAndRewrite(math::TanhOp op,
551 PatternRewriter &rewriter) const final;
552};
553} // namespace
554
555LogicalResult
556TanhApproximation::matchAndRewrite(math::TanhOp op,
557 PatternRewriter &rewriter) const {
558 if (!getElementTypeOrSelf(op.getOperand()).isF32())
559 return rewriter.notifyMatchFailure(op, "unsupported operand type");
560
561 std::optional<VectorShape> shape = vectorShape(op.getOperand());
562
563 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
564 auto bcast = [&](Value value) -> Value {
565 return broadcast(builder, value, shape);
566 };
567
568 // Clamp operand into the [minusClamp, plusClamp] range.
569 Value minusClamp = bcast(f32Cst(builder, -7.99881172180175781f));
570 Value plusClamp = bcast(f32Cst(builder, 7.99881172180175781f));
571 Value x = clamp(builder, op.getOperand(), minusClamp, plusClamp);
572
573 // Mask for tiny values that are approximated with `operand`.
574 Value tiny = bcast(f32Cst(builder, 0.0004f));
575 Value tinyMask = arith::CmpFOp::create(
576 builder, arith::CmpFPredicate::OLT,
577 math::AbsFOp::create(builder, op.getOperand()), tiny);
578
579 // The monomial coefficients of the numerator polynomial (odd).
580 Value alpha1 = bcast(f32Cst(builder, 4.89352455891786e-03f));
581 Value alpha3 = bcast(f32Cst(builder, 6.37261928875436e-04f));
582 Value alpha5 = bcast(f32Cst(builder, 1.48572235717979e-05f));
583 Value alpha7 = bcast(f32Cst(builder, 5.12229709037114e-08f));
584 Value alpha9 = bcast(f32Cst(builder, -8.60467152213735e-11f));
585 Value alpha11 = bcast(f32Cst(builder, 2.00018790482477e-13f));
586 Value alpha13 = bcast(f32Cst(builder, -2.76076847742355e-16f));
587
588 // The monomial coefficients of the denominator polynomial (even).
589 Value beta0 = bcast(f32Cst(builder, 4.89352518554385e-03f));
590 Value beta2 = bcast(f32Cst(builder, 2.26843463243900e-03f));
591 Value beta4 = bcast(f32Cst(builder, 1.18534705686654e-04f));
592 Value beta6 = bcast(f32Cst(builder, 1.19825839466702e-06f));
593
594 // Since the polynomials are odd/even, we need x^2.
595 Value x2 = arith::MulFOp::create(builder, x, x);
596
597 // Evaluate the numerator polynomial p.
598 Value p = math::FmaOp::create(builder, x2, alpha13, alpha11);
599 p = math::FmaOp::create(builder, x2, p, alpha9);
600 p = math::FmaOp::create(builder, x2, p, alpha7);
601 p = math::FmaOp::create(builder, x2, p, alpha5);
602 p = math::FmaOp::create(builder, x2, p, alpha3);
603 p = math::FmaOp::create(builder, x2, p, alpha1);
604 p = arith::MulFOp::create(builder, x, p);
605
606 // Evaluate the denominator polynomial q.
607 Value q = math::FmaOp::create(builder, x2, beta6, beta4);
608 q = math::FmaOp::create(builder, x2, q, beta2);
609 q = math::FmaOp::create(builder, x2, q, beta0);
610
611 // Divide the numerator by the denominator.
612 Value res = arith::SelectOp::create(builder, tinyMask, x,
613 arith::DivFOp::create(builder, p, q));
614
615 rewriter.replaceOp(op, res);
616
617 return success();
618}
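// Informative note: this is a rational approximation tanh(x) ~= x * P(x^2) /
// Q(x^2) with an odd numerator and an even denominator. The clamp bounds keep
// evaluation in the range where the approximation is valid (beyond them tanh
// is within a few ulps of +/-1), and inputs with |x| < 0.0004 are passed
// through unchanged since tanh(x) ~= x there to within f32 precision.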
619
620#define LN2_VALUE \
621 0.693147180559945309417232121458176568075500134360255254120680009493393621L
622#define LOG2E_VALUE \
623 1.442695040888963407359924681001892137426645954152985934135449406931109219L
624
625//----------------------------------------------------------------------------//
626// LogOp and Log2Op approximation.
627//----------------------------------------------------------------------------//
628
629namespace {
630template <typename Op>
631struct LogApproximationBase : public OpRewritePattern<Op> {
632 using OpRewritePattern<Op>::OpRewritePattern;
633
634 /// Base 2 if 'base2' is set; natural logarithm (base e) otherwise.
635 LogicalResult logMatchAndRewrite(Op op, PatternRewriter &rewriter,
636 bool base2) const;
637};
638} // namespace
639
640// This approximation comes from Julien Pommier's SSE math library.
641// Link: http://gruntthepeon.free.fr/ssemath
642template <typename Op>
643LogicalResult
644LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
645 bool base2) const {
646 if (!getElementTypeOrSelf(op.getOperand()).isF32())
647 return rewriter.notifyMatchFailure(op, "unsupported operand type");
648
649 std::optional<VectorShape> shape = vectorShape(op.getOperand());
650
651 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
652 auto bcast = [&](Value value) -> Value {
653 return broadcast(builder, value, shape);
654 };
655
656 Value cstZero = bcast(f32Cst(builder, 0.0f));
657 Value cstOne = bcast(f32Cst(builder, 1.0f));
658 Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
659
660 // The smallest non-denormalized float number.
661 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
662 Value cstMinusInf = bcast(f32FromBits(builder, 0xff800000u));
663 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
664 Value cstNan = bcast(f32FromBits(builder, 0x7fc00000));
665
666 // Polynomial coefficients.
667 Value cstCephesSQRTHF = bcast(f32Cst(builder, 0.707106781186547524f));
668 Value cstCephesLogP0 = bcast(f32Cst(builder, 7.0376836292E-2f));
669 Value cstCephesLogP1 = bcast(f32Cst(builder, -1.1514610310E-1f));
670 Value cstCephesLogP2 = bcast(f32Cst(builder, 1.1676998740E-1f));
671 Value cstCephesLogP3 = bcast(f32Cst(builder, -1.2420140846E-1f));
672 Value cstCephesLogP4 = bcast(f32Cst(builder, +1.4249322787E-1f));
673 Value cstCephesLogP5 = bcast(f32Cst(builder, -1.6668057665E-1f));
674 Value cstCephesLogP6 = bcast(f32Cst(builder, +2.0000714765E-1f));
675 Value cstCephesLogP7 = bcast(f32Cst(builder, -2.4999993993E-1f));
676 Value cstCephesLogP8 = bcast(f32Cst(builder, +3.3333331174E-1f));
677
678 Value x = op.getOperand();
679
680 // Truncate input values to the minimum positive normal.
681 x = max(builder, x, cstMinNormPos);
682
683 // Extract the significand in the range [0.5,1) and the exponent.
684 std::pair<Value, Value> pair = frexp(builder, x, /*isPositive=*/true);
685 x = pair.first;
686 Value e = pair.second;
687
688 // Shift the inputs from the range [0.5,1) to [sqrt(1/2), sqrt(2)) and shift
689 // by -1.0. The values are then centered around 0, which improves the
690 // stability of the polynomial evaluation:
691 //
692 // if( x < SQRTHF ) {
693 // e -= 1;
694 // x = x + x - 1.0;
695 // } else { x = x - 1.0; }
696 Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x,
697 cstCephesSQRTHF);
698 Value tmp = arith::SelectOp::create(builder, mask, x, cstZero);
699
700 x = arith::SubFOp::create(builder, x, cstOne);
701 e = arith::SubFOp::create(
702 builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero));
703 x = arith::AddFOp::create(builder, x, tmp);
704
705 Value x2 = arith::MulFOp::create(builder, x, x);
706 Value x3 = arith::MulFOp::create(builder, x2, x);
707
708 // Evaluate the polynomial approximant of degree 8 in three parts.
709 Value y0, y1, y2;
710 y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1);
711 y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4);
712 y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7);
713 y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2);
714 y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5);
715 y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8);
716 y0 = math::FmaOp::create(builder, y0, x3, y1);
717 y0 = math::FmaOp::create(builder, y0, x3, y2);
718 y0 = arith::MulFOp::create(builder, y0, x3);
719
720 y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0);
721 x = arith::AddFOp::create(builder, x, y0);
722
723 if (base2) {
724 Value cstLog2e = bcast(f32Cst(builder, static_cast<float>(LOG2E_VALUE)));
725 x = math::FmaOp::create(builder, x, cstLog2e, e);
726 } else {
727 Value cstLn2 = bcast(f32Cst(builder, static_cast<float>(LN2_VALUE)));
728 x = math::FmaOp::create(builder, e, cstLn2, x);
729 }
730
731 Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT,
732 op.getOperand(), cstZero);
733 Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstZero);
735 Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
736 op.getOperand(), cstPosInf);
737
738 // Filter out invalid values:
739 // • x == 0 -> -INF
740 // • x < 0 -> NAN
741 // • x == +INF -> +INF
742 Value approximation = arith::SelectOp::create(
743 builder, zeroMask, cstMinusInf,
744 arith::SelectOp::create(
745 builder, invalidMask, cstNan,
746 arith::SelectOp::create(builder, posInfMask, cstPosInf, x)));
747
748 rewriter.replaceOp(op, approximation);
749
750 return success();
751}
752
753namespace {
754struct LogApproximation : public LogApproximationBase<math::LogOp> {
755 using LogApproximationBase::LogApproximationBase;
756
757 LogicalResult matchAndRewrite(math::LogOp op,
758 PatternRewriter &rewriter) const final {
759 return logMatchAndRewrite(op, rewriter, /*base2=*/false);
760 }
761};
762} // namespace
763
764namespace {
765struct Log2Approximation : public LogApproximationBase<math::Log2Op> {
766 using LogApproximationBase::LogApproximationBase;
767
768 LogicalResult matchAndRewrite(math::Log2Op op,
769 PatternRewriter &rewriter) const final {
770 return logMatchAndRewrite(op, rewriter, /*base2=*/true);
771 }
772};
773} // namespace
774
775//----------------------------------------------------------------------------//
776// Log1p approximation.
777//----------------------------------------------------------------------------//
778
779namespace {
780struct Log1pApproximation : public OpRewritePattern<math::Log1pOp> {
781public:
782 using OpRewritePattern::OpRewritePattern;
783
784 LogicalResult matchAndRewrite(math::Log1pOp op,
785 PatternRewriter &rewriter) const final;
786};
787} // namespace
788
789// Approximate log(1+x).
790LogicalResult
791Log1pApproximation::matchAndRewrite(math::Log1pOp op,
792 PatternRewriter &rewriter) const {
793 if (!getElementTypeOrSelf(op.getOperand()).isF32())
794 return rewriter.notifyMatchFailure(op, "unsupported operand type");
795
796 std::optional<VectorShape> shape = vectorShape(op.getOperand());
797
798 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
799 auto bcast = [&](Value value) -> Value {
800 return broadcast(builder, value, shape);
801 };
802
803 // Approximate log(1+x) using the following, due to W. Kahan:
804 // u = x + 1.0;
805 // if (u == 1.0 || u == inf) return x;
806 // return x * log(u) / (u - 1.0);
807 // ^^^^^^^^^^^^^^^^^^^^^^
808 // "logLarge" below.
809 Value cstOne = bcast(f32Cst(builder, 1.0f));
810 Value x = op.getOperand();
811 Value u = arith::AddFOp::create(builder, x, cstOne);
812 Value uSmall =
813 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne);
814 Value logU = math::LogOp::create(builder, u);
815 Value uInf =
816 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU);
817 Value logLarge = arith::MulFOp::create(
818 builder, x,
819 arith::DivFOp::create(builder, logU,
820 arith::SubFOp::create(builder, u, cstOne)));
821 Value approximation = arith::SelectOp::create(
822 builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge);
823 rewriter.replaceOp(op, approximation);
824 return success();
825}
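// Why Kahan's formula helps (informal note): for tiny x, u = 1 + x loses the
// low bits of x to rounding, but log(u) and u - 1 are computed from the same
// rounded u, so the ratio log(u) / (u - 1) stays well conditioned and
// multiplying it by the exact x gives an accurate log1p. The `u == log(u)`
// comparison above detects u == +inf without materializing an infinity
// constant.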
826
827//----------------------------------------------------------------------------//
828// Asin approximation.
829//----------------------------------------------------------------------------//
830
831// Approximates asin(x).
832// This approximation is based on the following stackoverflow post:
833// https://stackoverflow.com/a/42683455
834namespace {
835struct AsinPolynomialApproximation : public OpRewritePattern<math::AsinOp> {
836public:
837 using OpRewritePattern::OpRewritePattern;
838
839 LogicalResult matchAndRewrite(math::AsinOp op,
840 PatternRewriter &rewriter) const final;
841};
842} // namespace
843LogicalResult
844AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
845 PatternRewriter &rewriter) const {
846 Value operand = op.getOperand();
847 Type elementType = getElementTypeOrSelf(operand);
848
849 if (!(elementType.isF32() || elementType.isF16()))
850 return rewriter.notifyMatchFailure(op,
851 "only f32 and f16 type is supported.");
852 std::optional<VectorShape> shape = vectorShape(operand);
853
854 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
855 auto bcast = [&](Value value) -> Value {
856 return broadcast(builder, value, shape);
857 };
858
859 auto fma = [&](Value a, Value b, Value c) -> Value {
860 return math::FmaOp::create(builder, a, b, c);
861 };
862
863 auto mul = [&](Value a, Value b) -> Value {
864 return arith::MulFOp::create(builder, a, b);
865 };
866
867 auto sub = [&](Value a, Value b) -> Value {
868 return arith::SubFOp::create(builder, a, b);
869 };
870
871 auto abs = [&](Value a) -> Value { return math::AbsFOp::create(builder, a); };
872
873 auto sqrt = [&](Value a) -> Value {
874 return math::SqrtOp::create(builder, a);
875 };
876
877 auto scopy = [&](Value a, Value b) -> Value {
878 return math::CopySignOp::create(builder, a, b);
879 };
880
881 auto sel = [&](Value a, Value b, Value c) -> Value {
882 return arith::SelectOp::create(builder, a, b, c);
883 };
884
885 Value abso = abs(operand);
886 Value aa = mul(operand, operand);
887 Value opp = sqrt(sub(bcast(floatCst(builder, 1.0, elementType)), aa));
888
889 Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa,
890 bcast(floatCst(builder, 0.5, elementType)));
891
892 Value x = sel(gt, opp, abso);
893
894 // Asin(x) approximation for x = [-9/16, 9/16]:
895 Value s = mul(x, x);
896 Value q = mul(s, s);
897 Value r = bcast(floatCst(builder, 5.5579749017470502e-2, elementType));
898 Value t = bcast(floatCst(builder, -6.2027913464120114e-2, elementType));
899
900 r = fma(r, q, bcast(floatCst(builder, 5.4224464349245036e-2, elementType)));
901 t = fma(t, q, bcast(floatCst(builder, -1.1326992890324464e-2, elementType)));
902 r = fma(r, q, bcast(floatCst(builder, 1.5268872539397656e-2, elementType)));
903 t = fma(t, q, bcast(floatCst(builder, 1.0493798473372081e-2, elementType)));
904 r = fma(r, q, bcast(floatCst(builder, 1.4106045900607047e-2, elementType)));
905 t = fma(t, q, bcast(floatCst(builder, 1.7339776384962050e-2, elementType)));
906 r = fma(r, q, bcast(floatCst(builder, 2.2372961589651054e-2, elementType)));
907 t = fma(t, q, bcast(floatCst(builder, 3.0381912707941005e-2, elementType)));
908 r = fma(r, q, bcast(floatCst(builder, 4.4642857881094775e-2, elementType)));
909 t = fma(t, q, bcast(floatCst(builder, 7.4999999991367292e-2, elementType)));
910 r = fma(r, s, t);
911 r = fma(r, s, bcast(floatCst(builder, 1.6666666666670193e-1, elementType)));
912 t = mul(x, s);
913 r = fma(r, t, x);
914
915 Value rsub = sub(bcast(floatCst(builder, 1.57079632679, elementType)), r);
916 r = sel(gt, rsub, r);
917 r = scopy(r, operand);
918
919 rewriter.replaceOp(op, r);
920 return success();
921}
922
923//----------------------------------------------------------------------------//
924// Acos approximation.
925//----------------------------------------------------------------------------//
926
927// Approximates acos(x).
928// This approximation is based on the following stackoverflow post:
929// https://stackoverflow.com/a/42683455
930namespace {
931struct AcosPolynomialApproximation : public OpRewritePattern<math::AcosOp> {
932public:
933 using OpRewritePattern::OpRewritePattern;
934
935 LogicalResult matchAndRewrite(math::AcosOp op,
936 PatternRewriter &rewriter) const final;
937};
938} // namespace
939LogicalResult
940AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941 PatternRewriter &rewriter) const {
942 Value operand = op.getOperand();
943 Type elementType = getElementTypeOrSelf(operand);
944
945 if (!(elementType.isF32() || elementType.isF16()))
946 return rewriter.notifyMatchFailure(op,
947 "only f32 and f16 type is supported.");
948 std::optional<VectorShape> shape = vectorShape(operand);
949
950 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
951 auto bcast = [&](Value value) -> Value {
952 return broadcast(builder, value, shape);
953 };
954
955 auto fma = [&](Value a, Value b, Value c) -> Value {
956 return math::FmaOp::create(builder, a, b, c);
957 };
958
959 auto mul = [&](Value a, Value b) -> Value {
960 return arith::MulFOp::create(builder, a, b);
961 };
962
963 Value negOperand = arith::NegFOp::create(builder, operand);
964 Value zero = bcast(floatCst(builder, 0.0, elementType));
965 Value half = bcast(floatCst(builder, 0.5, elementType));
966 Value negOne = bcast(floatCst(builder, -1.0, elementType));
967 Value selR =
968 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero);
969 Value r = arith::SelectOp::create(builder, selR, negOperand, operand);
970 Value chkConst = bcast(floatCst(builder, -0.5625, elementType));
971 Value firstPred =
972 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst);
973
974 Value trueVal =
975 fma(bcast(floatCst(builder, 9.3282184640716537e-1, elementType)),
976 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
977 math::AsinOp::create(builder, r));
978
979 Value falseVal = math::SqrtOp::create(builder, fma(half, r, half));
980 falseVal = math::AsinOp::create(builder, falseVal);
981 falseVal = mul(bcast(floatCst(builder, 2.0, elementType)), falseVal);
982
983 r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal);
984
985 // Check whether the operand lies in between [-1.0, 0.0).
986 Value greaterThanNegOne = arith::CmpFOp::create(
987 builder, arith::CmpFPredicate::OGE, operand, negOne);
988
989 Value lessThanZero =
990 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
991
992 Value betweenNegOneZero =
993 arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero);
994
995 trueVal = fma(bcast(floatCst(builder, 1.8656436928143307e+0, elementType)),
996 bcast(floatCst(builder, 1.6839188885261840e+0, elementType)),
997 arith::NegFOp::create(builder, r));
998
999 Value finalVal =
1000 arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r);
1001
1002 rewriter.replaceOp(op, finalVal);
1003 return success();
1004}
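// Note on the constants above (informative): the two constant pairs multiply
// out to approximately pi/2 (0.9328218464... * 1.6839188885... ~= 1.5707963)
// and pi (1.8656436928... * 1.6839188885... ~= 3.1415927), so the FmaOps
// evaluate acos(|x|) = pi/2 - asin(|x|) and, for negative inputs,
// acos(x) = pi - acos(|x|).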
1005
1006//----------------------------------------------------------------------------//
1007// Erf approximation.
1008//----------------------------------------------------------------------------//
1009
1010// Approximates erf(x) with
1011// a + P(x)/Q(x)
1012// where P and Q are polynomials of degree 4.
1013// Different coefficients are chosen based on the value of x.
1014// The approximation error is ~2.5e-07.
1015// Boost's minimax tool that utilizes the Remez method was used to find the
1016// coefficients.
1017LogicalResult
1018ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
1019 PatternRewriter &rewriter) const {
1020 Value operand = op.getOperand();
1021 Type elementType = getElementTypeOrSelf(operand);
1022
1023 if (!(elementType.isF32() || elementType.isF16()))
1024 return rewriter.notifyMatchFailure(op,
1025 "only f32 and f16 type is supported.");
1026 std::optional<VectorShape> shape = vectorShape(operand);
1027
1028 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1029 auto bcast = [&](Value value) -> Value {
1030 return broadcast(builder, value, shape);
1031 };
1032
1033 const int intervalsCount = 3;
1034 const int polyDegree = 4;
1035
1036 Value zero = bcast(floatCst(builder, 0, elementType));
1037 Value one = bcast(floatCst(builder, 1, elementType));
1038 Value pp[intervalsCount][polyDegree + 1];
1039 pp[0][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1040 pp[0][1] = bcast(floatCst(builder, +1.12837916222975858e+00f, elementType));
1041 pp[0][2] = bcast(floatCst(builder, -5.23018562988006470e-01f, elementType));
1042 pp[0][3] = bcast(floatCst(builder, +2.09741709609267072e-01f, elementType));
1043 pp[0][4] = bcast(floatCst(builder, +2.58146801602987875e-02f, elementType));
1044 pp[1][0] = bcast(floatCst(builder, +0.00000000000000000e+00f, elementType));
1045 pp[1][1] = bcast(floatCst(builder, +1.12750687816789140e+00f, elementType));
1046 pp[1][2] = bcast(floatCst(builder, -3.64721408487825775e-01f, elementType));
1047 pp[1][3] = bcast(floatCst(builder, +1.18407396425136952e-01f, elementType));
1048 pp[1][4] = bcast(floatCst(builder, +3.70645533056476558e-02f, elementType));
1049 pp[2][0] = bcast(floatCst(builder, -3.30093071049483172e-03f, elementType));
1050 pp[2][1] = bcast(floatCst(builder, +3.51961938357697011e-03f, elementType));
1051 pp[2][2] = bcast(floatCst(builder, -1.41373622814988039e-03f, elementType));
1052 pp[2][3] = bcast(floatCst(builder, +2.53447094961941348e-04f, elementType));
1053 pp[2][4] = bcast(floatCst(builder, -1.71048029455037401e-05f, elementType));
1054
1055 Value qq[intervalsCount][polyDegree + 1];
1056 qq[0][0] = bcast(floatCst(builder, +1.000000000000000000e+00f, elementType));
1057 qq[0][1] = bcast(floatCst(builder, -4.635138185962547255e-01f, elementType));
1058 qq[0][2] = bcast(floatCst(builder, +5.192301327279782447e-01f, elementType));
1059 qq[0][3] = bcast(floatCst(builder, -1.318089722204810087e-01f, elementType));
1060 qq[0][4] = bcast(floatCst(builder, +7.397964654672315005e-02f, elementType));
1061 qq[1][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1062 qq[1][1] = bcast(floatCst(builder, -3.27607011824493086e-01f, elementType));
1063 qq[1][2] = bcast(floatCst(builder, +4.48369090658821977e-01f, elementType));
1064 qq[1][3] = bcast(floatCst(builder, -8.83462621207857930e-02f, elementType));
1065 qq[1][4] = bcast(floatCst(builder, +5.72442770283176093e-02f, elementType));
1066 qq[2][0] = bcast(floatCst(builder, +1.00000000000000000e+00f, elementType));
1067 qq[2][1] = bcast(floatCst(builder, -2.06069165953913769e+00f, elementType));
1068 qq[2][2] = bcast(floatCst(builder, +1.62705939945477759e+00f, elementType));
1069 qq[2][3] = bcast(floatCst(builder, -5.83389859211130017e-01f, elementType));
1070 qq[2][4] = bcast(floatCst(builder, +8.21908939856640930e-02f, elementType));
1071
1072 Value offsets[intervalsCount];
1073 offsets[0] = bcast(floatCst(builder, 0.0f, elementType));
1074 offsets[1] = bcast(floatCst(builder, 0.0f, elementType));
1075 offsets[2] = bcast(floatCst(builder, 1.0f, elementType));
1076
1077 Value bounds[intervalsCount];
1078 bounds[0] = bcast(floatCst(builder, 0.8f, elementType));
1079 bounds[1] = bcast(floatCst(builder, 2.0f, elementType));
1080 bounds[2] = bcast(floatCst(builder, 3.75f, elementType));
1081
1082 Value isNegativeArg =
1083 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
1084 Value negArg = arith::NegFOp::create(builder, operand);
1085 Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand);
1086
1087 Value offset = offsets[0];
1088 Value p[polyDegree + 1];
1089 Value q[polyDegree + 1];
1090 for (int i = 0; i <= polyDegree; ++i) {
1091 p[i] = pp[0][i];
1092 q[i] = qq[0][i];
1093 }
1094
1095 // TODO: maybe use vector stacking to reduce the number of selects.
1096 Value isLessThanBound[intervalsCount];
1097 for (int j = 0; j < intervalsCount - 1; ++j) {
1098 isLessThanBound[j] =
1099 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[j]);
1100 for (int i = 0; i <= polyDegree; ++i) {
1101 p[i] = arith::SelectOp::create(builder, isLessThanBound[j], p[i],
1102 pp[j + 1][i]);
1103 q[i] = arith::SelectOp::create(builder, isLessThanBound[j], q[i],
1104 qq[j + 1][i]);
1105 }
1106 offset = arith::SelectOp::create(builder, isLessThanBound[j], offset,
1107 offsets[j + 1]);
1108 }
1109 isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(
1110 builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1111
1112 Value pPoly = makePolynomialCalculation(builder, p, x);
1113 Value qPoly = makePolynomialCalculation(builder, q, x);
1114 Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly);
1115 Value formula = arith::AddFOp::create(builder, offset, rationalPoly);
1116 formula = arith::SelectOp::create(
1117 builder, isLessThanBound[intervalsCount - 1], formula, one);
1118
1119 // erf is odd function: erf(x) = -erf(-x).
1120 Value negFormula = arith::NegFOp::create(builder, formula);
1121 Value res =
1122 arith::SelectOp::create(builder, isNegativeArg, negFormula, formula);
1123
1124 rewriter.replaceOp(op, res);
1125
1126 return success();
1127}
1128
1129// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a degree-9
1130// polynomial. This approximation is based on the following stackoverflow post:
1131// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
1132// The stackoverflow post is in turn based on:
1133// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
1134// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
1135// No. 153, January 1981, pp. 249-253.
1136//
1137// Maximum error: 2.65 ulps
1138LogicalResult
1139ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
1140 PatternRewriter &rewriter) const {
1141 Value x = op.getOperand();
1142 Type et = getElementTypeOrSelf(x);
1143
1144 if (!et.isF32())
1145 return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
1146 std::optional<VectorShape> shape = vectorShape(x);
1147
1148 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1149 auto bcast = [&](Value value) -> Value {
1150 return broadcast(builder, value, shape);
1151 };
1152
1153 Value trueValue = bcast(boolCst(builder, true));
1154 Value zero = bcast(floatCst(builder, 0.0f, et));
1155 Value one = bcast(floatCst(builder, 1.0f, et));
1156 Value onehalf = bcast(floatCst(builder, 0.5f, et));
1157 Value neg4 = bcast(floatCst(builder, -4.0f, et));
1158 Value neg2 = bcast(floatCst(builder, -2.0f, et));
1159 Value pos2 = bcast(floatCst(builder, 2.0f, et));
1160 Value posInf = bcast(floatCst(builder, INFINITY, et));
1161 Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
1162
1163 Value a = math::AbsFOp::create(builder, x);
1164 Value p = arith::AddFOp::create(builder, a, pos2);
1165 Value r = arith::DivFOp::create(builder, one, p);
1166 Value q = math::FmaOp::create(builder, neg4, r, one);
1167 Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one),
1168 neg2, a);
1169 Value e =
1170 math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t);
1171 q = math::FmaOp::create(builder, r, e, q);
1172
1173 p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
1174 Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
1175 p = math::FmaOp::create(builder, p, q, c1);
1176 Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
1177 p = math::FmaOp::create(builder, p, q, c2);
1178 Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
1179 p = math::FmaOp::create(builder, p, q, c3);
1180 Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
1181 p = math::FmaOp::create(builder, p, q, c4);
1182 Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
1183 p = math::FmaOp::create(builder, p, q, c5);
1184 Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
1185 p = math::FmaOp::create(builder, p, q, c6);
1186 Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
1187 p = math::FmaOp::create(builder, p, q, c7);
1188 Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
1189 p = math::FmaOp::create(builder, p, q, c8);
1190 Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
1191 p = math::FmaOp::create(builder, p, q, c9);
1192
1193 Value d = math::FmaOp::create(builder, pos2, a, one);
1194 r = arith::DivFOp::create(builder, one, d);
1195 q = math::FmaOp::create(builder, p, r, r);
1196 Value negfa = arith::NegFOp::create(builder, a);
1197 Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf);
1198 Value psubq = arith::SubFOp::create(builder, p, q);
1199 e = math::FmaOp::create(builder, fmaqah, pos2, psubq);
1200 r = math::FmaOp::create(builder, e, r, q);
1201
1202 Value s = arith::MulFOp::create(builder, a, a);
1203 e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s));
1204
1205 t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s);
1206 r = math::FmaOp::create(
1207 builder, r, e,
1208 arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t));
1209
1210 Value isNotLessThanInf = arith::XOrIOp::create(
1211 builder,
1212 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf),
1213 trueValue);
1214 r = arith::SelectOp::create(builder, isNotLessThanInf,
1215 arith::AddFOp::create(builder, x, x), r);
1216 Value isGreaterThanClamp =
1217 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal);
1218 r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r);
1219
1220 Value isNegative =
1221 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero);
1222 r = arith::SelectOp::create(builder, isNegative,
1223 arith::SubFOp::create(builder, pos2, r), r);
1224
1225 rewriter.replaceOp(op, r);
1226 return success();
1227}
1228//----------------------------------------------------------------------------//
1229// Exp approximation.
1230//----------------------------------------------------------------------------//
1231
1232namespace {
1233
1234Value clampWithNormals(ImplicitLocOpBuilder &builder,
1235 const std::optional<VectorShape> shape, Value value,
1236 float lowerBound, float upperBound) {
1237 assert(!std::isnan(lowerBound));
1238 assert(!std::isnan(upperBound));
1239
1240 auto bcast = [&](Value value) -> Value {
1241 return broadcast(builder, value, shape);
1242 };
1243
1244 auto selectCmp = [&builder](auto pred, Value value, Value bound) {
1245 return arith::SelectOp::create(
1246 builder, arith::CmpFOp::create(builder, pred, value, bound), value,
1247 bound);
1248 };
1249
1250 // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs.
1251 // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with
1252 // arith::{Max,Min}FOp.
1253 value = selectCmp(arith::CmpFPredicate::UGE, value,
1254 bcast(f32Cst(builder, lowerBound)));
1255 value = selectCmp(arith::CmpFPredicate::ULE, value,
1256 bcast(f32Cst(builder, upperBound)));
1257 return value;
1258}
1259
1260struct ExpApproximation : public OpRewritePattern<math::ExpOp> {
1261public:
1262 using OpRewritePattern::OpRewritePattern;
1263
1264 LogicalResult matchAndRewrite(math::ExpOp op,
1265 PatternRewriter &rewriter) const final;
1266};
1267
1268LogicalResult
1269ExpApproximation::matchAndRewrite(math::ExpOp op,
1270 PatternRewriter &rewriter) const {
1271 auto shape = vectorShape(op.getOperand().getType());
1272 auto elementTy = getElementTypeOrSelf(op.getType());
1273 if (!elementTy.isF32())
1274 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1275
1276 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1277
1278 auto add = [&](Value a, Value b) -> Value {
1279 return arith::AddFOp::create(builder, a, b);
1280 };
1281 auto bcast = [&](Value value) -> Value {
1282 return broadcast(builder, value, shape);
1283 };
1284 auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
1285 auto fmla = [&](Value a, Value b, Value c) {
1286 return math::FmaOp::create(builder, a, b, c);
1287 };
1288 auto mul = [&](Value a, Value b) -> Value {
1289 return arith::MulFOp::create(builder, a, b);
1290 };
1291
1292 // Polynomial approximation from Cephes.
1293 //
1294 // To compute e^x, we re-express it as
1295 //
1296 // e^x = e^(a + b)
1297 // = e^(a + n log(2))
1298 // = e^a * 2^n.
1299 //
1300 // We choose n = round(x / log(2)), restricting the value of `a` to
1301 // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. The
1302 // relative error between our approximation and the true value of e^a is less
1303 // than 2^-22.5 for all values of `a` within this range.
1304
1305 // Restrict input to a small range, including some values that evaluate to
1306 // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of
1307 // log(F32_EPSILON). We do so because this routine always flushes denormal
1308 // floating points to 0. Therefore, we only need to worry about exponentiating
1309 // up to the smallest representable non-denormal floating point, which is
1310 // 2^-126.
1311
1312 // Constants.
1313 Value cstHalf = bcast(f32Cst(builder, 0.5f));
1314 Value cstOne = bcast(f32Cst(builder, 1.0f));
1315
1316 // 1/log(2)
1317 Value cstLog2ef = bcast(f32Cst(builder, 1.44269504088896341f));
1318
1319 Value cstExpC1 = bcast(f32Cst(builder, -0.693359375f));
1320 Value cstExpC2 = bcast(f32Cst(builder, 2.12194440e-4f));
1321 Value cstExpP0 = bcast(f32Cst(builder, 1.9875691500E-4f));
1322 Value cstExpP1 = bcast(f32Cst(builder, 1.3981999507E-3f));
1323 Value cstExpP2 = bcast(f32Cst(builder, 8.3334519073E-3f));
1324 Value cstExpP3 = bcast(f32Cst(builder, 4.1665795894E-2f));
1325 Value cstExpP4 = bcast(f32Cst(builder, 1.6666665459E-1f));
1326 Value cstExpP5 = bcast(f32Cst(builder, 5.0000001201E-1f));
1327
1328 // Our computations below aren't particularly sensitive to the exact choices
1329 // here, so we choose values a bit larger/smaller than
1330 //
1331 // log(F32_MAX) = 88.723...
1332 // log(2^-126) = -87.337...
1333 Value x = op.getOperand();
1334 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1335 Value n = floor(fmla(x, cstLog2ef, cstHalf));
1336
1337 // When we eventually do the multiplication in e^a * 2^n, we need to handle
1338 // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1
1339 // (so e^a * 2^n != inf). There's a similar problem for n < -126, the
1340 // smallest fp32 exponent.
1341 //
1342 // A straightforward solution would be to detect n out of range and split it
1343 // up, doing
1344 //
1345 // e^a * 2^n = e^a * 2^(n1 + n2)
1346 // = (2^n1 * e^a) * 2^n2.
1347 //
1348 // But it turns out this approach is quite slow, probably because it
1349 // manipulates subnormal values.
1350 //
1351 // The approach we use instead is to clamp n to [-127, 127]. Let n' be the
1352 // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow
1353 // up to as large as 88.8 - 127 * log(2) which is about 0.7703. Even though
1354 // this value of `a` is outside our previously specified range, e^a will still
1355 // only have a relative error of approximately 2^-16 at worst. In practice
1356 // this seems to work well enough; it passes our exhaustive tests, breaking
1357 // only one result, and by one ulp (we return exp(88.7228394) = max-float but
1358 // we should return inf).
1359 //
1360 // In the case where n' = -127, the original input value of x is so small that
1361 // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest
1362 // normal floating point, and since we flush denormals, we simply return 0. We
1363 // do this in a branchless way by observing that our code for constructing 2^n
1364 // produces 0 if n = -127.
1365 //
1366 // The proof that n' = -127 implies e^x < 2^-126 is as follows:
1367 //
1368 // n' = -127 implies n <= -127
1369 // implies round(x / log(2)) <= -127
1370 // implies x/log(2) < -126.5
1371 // implies x < -126.5 * log(2)
1372 // implies e^x < e^(-126.5 * log(2))
1373 // implies e^x < 2^-126.5 < 2^-126
1374 //
1375 // This proves that n' = -127 implies e^x < 2^-126.
1376 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1377
1378 // Computes x = x - n' * log(2), the value for `a`
1379 x = fmla(cstExpC1, n, x);
1380 x = fmla(cstExpC2, n, x);
1381
1382 // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5).
1383 Value z = fmla(x, cstExpP0, cstExpP1);
1384 z = fmla(z, x, cstExpP2);
1385 z = fmla(z, x, cstExpP3);
1386 z = fmla(z, x, cstExpP4);
1387 z = fmla(z, x, cstExpP5);
1388 z = fmla(z, mul(x, x), x);
1389 z = add(cstOne, z);
1390
1391 // Convert n' to an i32. This is safe because we clamped it above.
1392 auto i32Vec = broadcast(builder.getI32Type(), shape);
1393 Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n);
1394
1395 // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1396 Value pow2 = exp2I32(builder, nI32);
1397
1398 // Returns z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127.
1399 Value ret = mul(z, pow2);
1400
1401 rewriter.replaceOp(op, ret);
1402 return mlir::success();
1403}
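A minimal scalar sketch of the computation emitted above, using the same constants; it is illustrative only. The helper name approxExpF32 is hypothetical, and std::ldexp stands in for the exp2I32 exponent-bit construction, so this sketch does not reproduce the flush-to-zero behavior at n' = -127.

#include <cmath>

static float approxExpF32(float x) {
  // Clamp the input to the same safe range as the lowering.
  x = std::fmax(-87.8f, std::fmin(88.8f, x));
  // n = round(x / log(2)), clamped to [-127, 127].
  float n = std::floor(x * 1.44269504088896341f + 0.5f);
  n = std::fmax(-127.0f, std::fmin(127.0f, n));
  // a = x - n * log(2), with log(2) split in two pieces for accuracy.
  float a = x + n * -0.693359375f;
  a = a + n * 2.12194440e-4f;
  // Horner evaluation of the e^a polynomial, accurate for a in (-0.5, 0.5).
  float z = 1.9875691500E-4f;
  z = z * a + 1.3981999507E-3f;
  z = z * a + 8.3334519073E-3f;
  z = z * a + 4.1665795894E-2f;
  z = z * a + 1.6666665459E-1f;
  z = z * a + 5.0000001201E-1f;
  z = z * (a * a) + a;
  z = 1.0f + z;
  // Scale by 2^n (std::ldexp here; the pass builds 2^n from exponent bits).
  return std::ldexp(z, static_cast<int>(n));
}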
1404
1405} // namespace
1406
1407//----------------------------------------------------------------------------//
1408// ExpM1 approximation.
1409//----------------------------------------------------------------------------//
1410
1411namespace {
1412
1413struct ExpM1Approximation : public OpRewritePattern<math::ExpM1Op> {
1414public:
1415 using OpRewritePattern::OpRewritePattern;
1416 
1417 LogicalResult matchAndRewrite(math::ExpM1Op op,
1418 PatternRewriter &rewriter) const final;
1419};
1420} // namespace
1421
1422LogicalResult
1423ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1424 PatternRewriter &rewriter) const {
1425 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1426 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1427
1428 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1429
1430 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1431 auto bcast = [&](Value value) -> Value {
1432 return broadcast(builder, value, shape);
1433 };
1434
1435 // expm1(x) = exp(x) - 1 = u - 1.
1436 // We have to handle it carefully when x is near 0, i.e. u ~= 1,
1437 // and when the input is ~= -inf, i.e. u - 1 ~= -1.
1438 Value cstOne = bcast(f32Cst(builder, 1.0f));
1439 Value cstNegOne = bcast(f32Cst(builder, -1.0f));
1440 Value x = op.getOperand();
1441 Value u = math::ExpOp::create(builder, x);
1442 Value uEqOneOrNaN =
1443 arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne);
1444 Value uMinusOne = arith::SubFOp::create(builder, u, cstOne);
1445 Value uMinusOneEqNegOne = arith::CmpFOp::create(
1446 builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1447 // logU = log(u) ~= x
1448 Value logU = math::LogOp::create(builder, u);
1449
1450 // Detect exp(x) = +inf; written this way to avoid having to form +inf.
1451 Value isInf =
1452 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u);
1453
1454 // expm1(x) ~= (u - 1) * (x / log(u)), where log(u) ~= x.
1455 Value expm1 = arith::MulFOp::create(builder, uMinusOne,
1456 arith::DivFOp::create(builder, x, logU));
1457 expm1 = arith::SelectOp::create(builder, isInf, u, expm1);
1458 Value approximation = arith::SelectOp::create(
1459 builder, uEqOneOrNaN, x,
1460 arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1));
1461 rewriter.replaceOp(op, approximation);
1462 return success();
1463}
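A scalar sketch of the correction used by this rewrite: with u = exp(x), expm1(x) is recovered as (u - 1) * x / log(u) so that the rounding error in u - 1 cancels near x = 0. The function name approxExpM1F32 is hypothetical; the pass emits the equivalent MLIR ops instead.

#include <cmath>

static float approxExpM1F32(float x) {
  float u = std::exp(x);
  if (u == 1.0f || std::isnan(u))
    return x;                      // exp(x) rounded to exactly 1 (or NaN input).
  float uMinusOne = u - 1.0f;
  if (uMinusOne == -1.0f)
    return -1.0f;                  // x ~= -inf, so expm1(x) saturates at -1.
  float logU = std::log(u);        // logU ~= x, but consistent with the rounded u.
  if (logU == u)
    return u;                      // exp(x) overflowed to +inf.
  return uMinusOne * (x / logU);   // The division cancels the error in u - 1.
}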
1464
1465//----------------------------------------------------------------------------//
1466// Sin and Cos approximation.
1467//----------------------------------------------------------------------------//
1468
1469namespace {
1470
1471template <bool isSine, typename OpTy>
1472struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
1473public:
1474 using OpRewritePattern<OpTy>::OpRewritePattern;
1475
1476 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
1477};
1478} // namespace
1479
1480#define TWO_OVER_PI \
1481 0.6366197723675813430755350534900574481378385829618257949906693762L
1482#define PI_OVER_2 \
1483 1.5707963267948966192313216916397514420985846996875529104874722961L
1484
1485 // Approximates sin(x) or cos(x) by finding the best approximation polynomial on
1486 // the reduced range [0, pi/2] for both sin(x) and cos(x). Then, for y in that
1487 // reduced range, sin(x) is computed as sin(y), -sin(y), cos(y), or -cos(y).
1488template <bool isSine, typename OpTy>
1489LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1490 OpTy op, PatternRewriter &rewriter) const {
1491 static_assert(
1492 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1493 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1494
1495 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1496 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1497
1498 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1499
1500 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1501 auto bcast = [&](Value value) -> Value {
1502 return broadcast(builder, value, shape);
1503 };
1504 auto mul = [&](Value a, Value b) -> Value {
1505 return arith::MulFOp::create(builder, a, b);
1506 };
1507 auto sub = [&](Value a, Value b) -> Value {
1508 return arith::SubFOp::create(builder, a, b);
1509 };
1510 auto floor = [&](Value a) { return math::FloorOp::create(builder, a); };
1511
1512 auto i32Vec = broadcast(builder.getI32Type(), shape);
1513 auto fpToSignedInteger = [&](Value a) -> Value {
1514 return arith::FPToSIOp::create(builder, i32Vec, a);
1515 };
1516
1517 auto modulo4 = [&](Value a) -> Value {
1518 return arith::AndIOp::create(builder, a, bcast(i32Cst(builder, 3)));
1519 };
1520
1521 auto isEqualTo = [&](Value a, Value b) -> Value {
1522 return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b);
1523 };
1524
1525 auto isGreaterThan = [&](Value a, Value b) -> Value {
1526 return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b);
1527 };
1528
1529 auto select = [&](Value cond, Value t, Value f) -> Value {
1530 return arith::SelectOp::create(builder, cond, t, f);
1531 };
1532
1533 auto fmla = [&](Value a, Value b, Value c) {
1534 return math::FmaOp::create(builder, a, b, c);
1535 };
1536
1537 auto bitwiseOr = [&](Value a, Value b) {
1538 return arith::OrIOp::create(builder, a, b);
1539 };
1540
1541 Value twoOverPi = bcast(f32Cst(builder, (float)TWO_OVER_PI));
1542 Value piOverTwo = bcast(f32Cst(builder, (float)PI_OVER_2));
1543
1544 Value x = op.getOperand();
1545
1546 Value k = floor(mul(x, twoOverPi));
1547
1548 Value y = sub(x, mul(k, piOverTwo));
1549
1550 Value cstOne = bcast(f32Cst(builder, 1.0));
1551 Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
1552
1553 Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
1554 Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
1555 Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1556 Value cstSC8 =
1557 bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1558 Value cstSC10 =
1559 bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1560
1561 Value cstCC2 = bcast(f32Cst(builder, -0.5f));
1562 Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
1563 Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
1564 Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1565 Value cstCC10 =
1566 bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1567
1568 Value kMod4 = modulo4(fpToSignedInteger(k));
1569
1570 Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
1571 Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
1572 Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
1573 Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
1574
1575 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1576 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1)))
1577 : bitwiseOr(kR1, kR2);
1578
1579 Value y2 = mul(y, y);
1580
1581 Value base = select(sinuseCos, cstOne, y);
1582 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1583 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1584 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1585 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1586 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1587
1588 Value v1 = fmla(y2, cstC10, cstC8);
1589 Value v2 = fmla(y2, v1, cstC6);
1590 Value v3 = fmla(y2, v2, cstC4);
1591 Value v4 = fmla(y2, v3, cstC2);
1592 Value v5 = fmla(y2, v4, cstOne);
1593 Value v6 = mul(base, v5);
1594
1595 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1596
1597 rewriter.replaceOp(op, approximation);
1598
1599 return success();
1600}
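A scalar sketch of the quadrant selection above for the sine case (isSine = true): the argument is reduced by multiples of pi/2 and the polynomial and sign are then picked from k mod 4. The name approxSinF32 is hypothetical, and std::sin/std::cos stand in for the two minimax polynomials to keep the sketch short; the cosine case swaps the polynomial choice and negates for k mod 4 in {1, 2}.

#include <cmath>
#include <cstdint>

static float approxSinF32(float x) {
  const float kTwoOverPi = 0.63661977236758134f;
  const float kPiOverTwo = 1.57079632679489662f;
  float k = std::floor(x * kTwoOverPi);
  float y = x - k * kPiOverTwo;                  // y lies in [0, pi/2).
  int32_t kMod4 = static_cast<int32_t>(k) & 3;   // Quadrant index.
  bool useCos = (kMod4 == 1) || (kMod4 == 3);    // sin there equals +/-cos(y).
  bool negate = kMod4 > 1;                       // Quadrants 2 and 3 are negative.
  float v = useCos ? std::cos(y) : std::sin(y);  // Polynomials in the real pass.
  return negate ? -v : v;
}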
1601
1602//----------------------------------------------------------------------------//
1603// Cbrt approximation.
1604//----------------------------------------------------------------------------//
1605
1606namespace {
1607struct CbrtApproximation : public OpRewritePattern<math::CbrtOp> {
1608 using OpRewritePattern::OpRewritePattern;
1609 
1610 LogicalResult matchAndRewrite(math::CbrtOp op,
1611 PatternRewriter &rewriter) const final;
1612};
1613} // namespace
1614
1615// Estimation of cube-root using an algorithm defined in
1616// Hacker's Delight 2nd Edition.
1617LogicalResult
1618CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1619 PatternRewriter &rewriter) const {
1620 auto operand = op.getOperand();
1621 if (!getElementTypeOrSelf(operand).isF32())
1622 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1623
1624 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
1625 std::optional<VectorShape> shape = vectorShape(operand);
1626
1627 Type floatTy = getElementTypeOrSelf(operand.getType());
1628 Type intTy = b.getIntegerType(floatTy.getIntOrFloatBitWidth());
1629
1630 // Convert to vector types if necessary.
1631 floatTy = broadcast(floatTy, shape);
1632 intTy = broadcast(intTy, shape);
1633
1634 auto bconst = [&](TypedAttr attr) -> Value {
1635 Value value = arith::ConstantOp::create(b, attr);
1636 return broadcast(b, value, shape);
1637 };
1638
1639 // Declare the initial values:
1640 Value intTwo = bconst(b.getI32IntegerAttr(2));
1641 Value intFour = bconst(b.getI32IntegerAttr(4));
1642 Value intEight = bconst(b.getI32IntegerAttr(8));
1643 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1644 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1645 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1646 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1647
1648 // Compute an initial estimate by scaling the bit pattern by roughly one third:
1649 // union {int ix; float x;};
1650 // x = x0;
1651 // ix = ix/4 + ix/16;
1652 Value absValue = math::AbsFOp::create(b, operand);
1653 Value intValue = arith::BitcastOp::create(b, intTy, absValue);
1654 Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo);
1655 Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1656 intValue = arith::AddIOp::create(b, divideBy4, divideBy16);
1657
1658 // ix = ix + ix/16;
1659 divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1660 intValue = arith::AddIOp::create(b, intValue, divideBy16);
1661
1662 // ix = ix + ix/256;
1663 Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight);
1664 intValue = arith::AddIOp::create(b, intValue, divideBy256);
1665
1666 // ix = 0x2a5137a0 + ix;
1667 intValue = arith::AddIOp::create(b, intValue, intMagic);
1668
1669 // Perform one Newton step:
1670 // x = 0.33333333f*(2.0f*x + x0/(x*x));
1671 Value floatValue = arith::BitcastOp::create(b, floatTy, intValue);
1672 Value squared = arith::MulFOp::create(b, floatValue, floatValue);
1673 Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1674 Value divSquared = arith::DivFOp::create(b, absValue, squared);
1675 floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1676 floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1677
1678 // Perform a second Newton step: x = 0.33333333f*(2.0f*x + x0/(x*x));
1679 squared = arith::MulFOp::create(b, floatValue, floatValue);
1680 mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1681 divSquared = arith::DivFOp::create(b, absValue, squared);
1682 floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1683 floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1684
1685 // Check for zero and restore sign.
1686 Value isZero =
1687 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero);
1688 floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue);
1689 floatValue = math::CopySignOp::create(b, floatValue, operand);
1690
1691 rewriter.replaceOp(op, floatValue);
1692 return success();
1693}
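A scalar rendering of the cube-root estimate emitted above: a bit-level initial guess in the style of Hacker's Delight followed by two Newton steps. The name approxCbrtF32 is hypothetical, and std::memcpy plays the role of the bitcast ops.

#include <cmath>
#include <cstdint>
#include <cstring>

static float approxCbrtF32(float x0) {
  float ax = std::fabs(x0);
  uint32_t ix;
  std::memcpy(&ix, &ax, sizeof(ix));   // Bitcast the magnitude to an integer.
  // Initial guess: scale the bit pattern by roughly 1/3 and add a magic bias.
  ix = ix / 4 + ix / 16;
  ix += ix / 16;
  ix += ix / 256;
  ix += 0x2a5137a0u;
  float x;
  std::memcpy(&x, &ix, sizeof(x));
  // Two Newton steps for t^3 = ax: x = (2*x + ax / (x*x)) / 3.
  x = 0.33333333f * (2.0f * x + ax / (x * x));
  x = 0.33333333f * (2.0f * x + ax / (x * x));
  if (ax == 0.0f)
    x = 0.0f;                          // Return exactly +/-0 for a zero input.
  return std::copysign(x, x0);         // Restore the sign of the operand.
}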
1694
1695//----------------------------------------------------------------------------//
1696// Rsqrt approximation.
1697//----------------------------------------------------------------------------//
1698
1699namespace {
1700struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
1701 using OpRewritePattern::OpRewritePattern;
1702 
1703 LogicalResult matchAndRewrite(math::RsqrtOp op,
1704 PatternRewriter &rewriter) const final;
1705};
1706} // namespace
1707
1708LogicalResult
1709RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1710 PatternRewriter &rewriter) const {
1711 if (!getElementTypeOrSelf(op.getOperand()).isF32())
1712 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1713
1714 std::optional<VectorShape> shape = vectorShape(op.getOperand());
1715
1716 // Only support already-vectorized rsqrt ops whose innermost dimension is a
1717 // multiple of 8.
1717 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1718 return rewriter.notifyMatchFailure(op, "unsupported operand type");
1719
1720 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1721 auto bcast = [&](Value value) -> Value {
1722 return broadcast(builder, value, shape);
1723 };
1724
1725 Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
1726 Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
1727 Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
1728 Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));
1729
1730 Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf);
1731
1732 // Select only the inverse sqrt of positive normals (denormals are
1733 // flushed to zero).
1734 Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT,
1735 op.getOperand(), cstMinNormPos);
1736 Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
1737 op.getOperand(), cstPosInf);
1738 Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask);
1739
1740 // Compute an approximate result.
1741 Value yApprox = handleMultidimensionalVectors(
1742 builder, op->getOperands(), 8, [&builder](ValueRange operands) -> Value {
1743 return x86vector::RsqrtOp::create(builder, operands);
1744 });
1745
1746 // Do a single step of Newton-Raphson iteration to improve the approximation.
1747 // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
1748 // It is essential to evaluate the inner term like this because forming
1749 // y_n^2 may over- or underflow.
1750 Value inner = arith::MulFOp::create(builder, negHalf, yApprox);
1751 Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive);
1752 Value yNewton = arith::MulFOp::create(builder, yApprox, fma);
1753
1754 // Select the result of the Newton-Raphson step for positive normal arguments.
1755 // For other arguments, choose the output of the intrinsic. This will
1756 // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
1757 // x is zero or a positive denormalized float (equivalent to flushing positive
1758 // denormalized inputs to zero).
1759 Value res =
1760 arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton);
1761 rewriter.replaceOp(op, res);
1762
1763 return success();
1764}
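A scalar sketch of the refinement above: one Newton-Raphson iteration on top of the hardware estimate, y1 = y0 * (1.5 - 0.5 * x * y0^2), with the inner product ordered so that y0^2 is never formed directly. Here rsqrtEstimate is a hypothetical stand-in for the AVX rsqrt intrinsic, not an MLIR API.

#include <cmath>

static float refineRsqrt(float x, float (*rsqrtEstimate)(float)) {
  float y0 = rsqrtEstimate(x);
  // Keep the raw estimate for denormal, zero, negative, or infinite inputs,
  // mirroring the notNormalFiniteMask select above.
  if (x < 1.17549435e-38f /* FLT_MIN */ || std::isinf(x))
    return y0;
  // y1 = y0 * (1.5 + y0 * ((-0.5 * x) * y0)), i.e. y0 * fma(y0, inner, 1.5).
  float inner = (-0.5f * x) * y0;
  return y0 * std::fma(y0, inner, 1.5f);
}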
1765
1766//----------------------------------------------------------------------------//
1767
1768 void mlir::populatePolynomialApproximateTanhPattern(
1769 RewritePatternSet &patterns) {
1770 patterns.add<TanhApproximation>(patterns.getContext());
1771 }
1772 
1773 void mlir::populatePolynomialApproximateErfPattern(
1774 RewritePatternSet &patterns) {
1775 patterns.add<ErfPolynomialApproximation>(patterns.getContext());
1776 }
1777 
1778 void mlir::populatePolynomialApproximateErfcPattern(
1779 RewritePatternSet &patterns) {
1780 patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
1781 }
1782 
1783template <typename OpType>
1784static void
1785 populateMathF32ExpansionPattern(RewritePatternSet &patterns,
1786 llvm::function_ref<bool(StringRef)> predicate,
1787 PatternBenefit benefit) {
1788 if (predicate(OpType::getOperationName())) {
1789 patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
1790 }
1791}
1792
1793 void mlir::populateMathF32ExpansionPatterns(
1794 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1795 PatternBenefit benefit) {
1822}
1823
1824template <typename OpType, typename PatternType>
1825 static void populateMathPolynomialApproximationPattern(
1826 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1827 PatternBenefit benefit) {
1828 if (predicate(OpType::getOperationName())) {
1829 patterns.add<PatternType>(patterns.getContext(), benefit);
1830 }
1831}
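A client that only wants a subset of these approximations can pass a narrower predicate to the populate entry point. A hedged sketch, with the op selection and helper name chosen purely for illustration:

#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"

// Illustrative only: register polynomial approximations for exp and tanh.
static void addExpAndTanhApproximations(mlir::RewritePatternSet &patterns) {
  auto predicate = [](llvm::StringRef name) {
    return name == mlir::math::ExpOp::getOperationName() ||
           name == mlir::math::TanhOp::getOperationName();
  };
  mlir::populateMathPolynomialApproximationPatterns(patterns, predicate);
}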
1832
1833 void mlir::populateMathPolynomialApproximationPatterns(
1834 RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
1835 PatternBenefit benefit) {
1836 populateMathPolynomialApproximationPattern<AcosOp,
1837 AcosPolynomialApproximation>(
1838 patterns, predicate, benefit);
1839 populateMathPolynomialApproximationPattern<AsinOp,
1840 AsinPolynomialApproximation>(
1841 patterns, predicate, benefit);
1842 populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1843 patterns, predicate, benefit);
1844 populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1845 patterns, predicate, benefit);
1846 populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1847 patterns, predicate, benefit);
1848 populateMathPolynomialApproximationPattern<
1849 CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
1850 benefit);
1851 populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1852 patterns, predicate, benefit);
1853 populateMathPolynomialApproximationPattern<ErfcOp,
1854 ErfcPolynomialApproximation>(
1855 patterns, predicate, benefit);
1856 populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1857 patterns, predicate, benefit);
1858 populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1859 patterns, predicate, benefit);
1860 populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1861 patterns, predicate, benefit);
1862 populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1863 patterns, predicate, benefit);
1864 populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1865 patterns, predicate, benefit);
1866 populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1867 patterns, predicate, benefit);
1868 populateMathPolynomialApproximationPattern<
1869 SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
1870 benefit);
1871 populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1872 patterns, predicate, benefit);
1873 }
1874
1875 void mlir::populateMathPolynomialApproximationPatterns(
1876 RewritePatternSet &patterns,
1877 const MathPolynomialApproximationOptions &options) {
1878 mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) -> bool {
1879 return llvm::is_contained(
1880 {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1881 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1882 math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1883 math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1884 math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1885 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1886 math::CosOp::getOperationName()},
1887 name);
1888 });
1889
1890 mlir::populateMathPolynomialApproximationPatterns(
1891 patterns, [](StringRef name) -> bool {
1892 return llvm::is_contained(
1893 {math::AtanOp::getOperationName(),
1894 math::Atan2Op::getOperationName(),
1895 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1896 math::Log2Op::getOperationName(),
1897 math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1898 math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1899 math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1900 math::ExpM1Op::getOperationName(),
1901 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1902 math::CosOp::getOperationName()},
1903 name);
1904 });
1905
1906 if (options.enableAvx2) {
1907 auto predicateRsqrt = [](StringRef name) {
1908 return name == math::RsqrtOp::getOperationName();
1909 };
1910 mlir::populateMathPolynomialApproximationPatterns(patterns,
1911 predicateRsqrt);
1912 }
1913}
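A typical client builds a pattern set with one of the populate entry points above and runs it through the greedy rewrite driver. The wiring below is a hedged sketch: the function name approximateAllMathOps is invented, and the include paths are the usual locations for these declarations but may differ between MLIR versions.

#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void approximateAllMathOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Accept every op the predicate is queried for.
  mlir::populateMathPolynomialApproximationPatterns(
      patterns, [](llvm::StringRef) { return true; });
  // Apply the patterns to a fixpoint; the convergence result is ignored here.
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}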