34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
50 if (
auto vectorType = dyn_cast<VectorType>(type)) {
51 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
66 assert(!isa<VectorType>(type) &&
"must be scalar type");
67 return shape ?
VectorType::get(shape->sizes, type, shape->scalableFlags)
73 std::optional<VectorShape> shape) {
74 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
76 return shape ? builder.
create<BroadcastOp>(type, value) : value;
102 assert(!operands.empty() &&
"operands must be not empty");
103 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
105 VectorType inputType = cast<VectorType>(operands[0].
getType());
111 return compute(operands);
115 int64_t innerDim = inputShape.back();
116 int64_t expansionDim = innerDim / vectorWidth;
117 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
124 if (expansionDim > 1) {
126 expandedShape.insert(expandedShape.end() - 1, expansionDim);
127 expandedShape.back() = vectorWidth;
129 for (
unsigned i = 0; i < operands.size(); ++i) {
130 auto operand = operands[i];
131 auto eltType = cast<VectorType>(operand.getType()).getElementType();
133 expandedOperands[i] =
134 builder.
create<vector::ShapeCastOp>(expandedType, operand);
146 for (int64_t i = 0; i < maxIndex; ++i) {
151 extracted[tuple.index()] =
152 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
154 results[i] = compute(extracted);
158 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
161 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
163 for (int64_t i = 0; i < maxIndex; ++i)
164 result = builder.
create<vector::InsertOp>(results[i], result,
168 return builder.
create<vector::ShapeCastOp>(
178 assert((elementType.
isF16() || elementType.
isF32()) &&
179 "x must be f16 or f32 type.");
180 return builder.
create<arith::ConstantOp>(
193 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
203 return builder.
create<arith::SelectOp>(
204 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
210 return builder.
create<arith::SelectOp>(
211 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
218 return max(builder,
min(builder, value, upperBound), lowerBound);
224 bool isPositive =
false) {
226 std::optional<VectorShape> shape =
vectorShape(arg);
241 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
242 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
243 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
246 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
247 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
248 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
251 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
252 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
253 builder.
create<arith::BitcastOp>(i32Vec, arg0),
254 bcast(
i32Cst(builder, 23)));
255 Value biasedExponent =
256 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
258 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
260 return {normalizedFraction, exponent};
266 std::optional<VectorShape> shape =
vectorShape(arg);
274 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
276 auto bias = bcast(
i32Cst(builder, 127));
278 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
280 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
281 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
290 assert((elementType.
isF32() || elementType.
isF16()) &&
291 "x must be f32 or f16 type");
297 if (coeffs.size() == 1)
300 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
301 coeffs[coeffs.size() - 2]);
302 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
303 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
313 template <
typename T>
331 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
332 newType = shaped.clone(rewriter.
getF32Type());
333 }
else if (isa<FloatType>(origType)) {
337 "unable to find F32 equivalent type");
343 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
356 template <
typename T>
360 LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter)
const final {
362 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
363 "requires same operands and result types");
364 return insertCasts<T>(op, rewriter);
378 LogicalResult matchAndRewrite(math::AtanOp op,
384 AtanApproximation::matchAndRewrite(math::AtanOp op,
386 auto operand = op.getOperand();
390 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
400 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
404 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
411 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
413 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
414 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
415 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
417 Value x = builder.
create<arith::DivFOp>(xnum, xden);
422 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
423 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
424 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
425 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
426 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
427 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
428 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
429 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
430 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
431 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
435 n = builder.
create<math::FmaOp>(xx, n, p1);
436 n = builder.
create<math::FmaOp>(xx, n, p2);
437 n = builder.
create<math::FmaOp>(xx, n, p3);
438 n = builder.
create<math::FmaOp>(xx, n, p4);
439 n = builder.
create<arith::MulFOp>(n, xx);
443 d = builder.
create<math::FmaOp>(xx, d, q1);
444 d = builder.
create<math::FmaOp>(xx, d, q2);
445 d = builder.
create<math::FmaOp>(xx, d, q3);
446 d = builder.
create<math::FmaOp>(xx, d, q4);
450 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
453 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
454 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
455 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
457 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
458 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
459 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
475 LogicalResult matchAndRewrite(math::Atan2Op op,
481 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
483 auto y = op.getOperand(0);
484 auto x = op.getOperand(1);
489 std::optional<VectorShape> shape =
vectorShape(op.getResult());
492 auto div = builder.
create<arith::DivFOp>(y, x);
493 auto atan = builder.
create<math::AtanOp>(div);
498 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
499 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
501 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
502 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
505 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
506 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
510 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
511 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
512 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
513 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
514 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
517 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
518 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
519 auto negativeHalfPiPi =
521 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
526 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
527 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
529 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
544 LogicalResult matchAndRewrite(math::TanhOp op,
550 TanhApproximation::matchAndRewrite(math::TanhOp op,
555 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
563 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
564 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
565 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
570 arith::CmpFPredicate::OLT, builder.
create<math::AbsFOp>(op.getOperand()),
574 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
575 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
576 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
577 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
578 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
579 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
580 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
583 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
584 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
585 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
586 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
592 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
593 p = builder.
create<math::FmaOp>(x2, p, alpha9);
594 p = builder.
create<math::FmaOp>(x2, p, alpha7);
595 p = builder.
create<math::FmaOp>(x2, p, alpha5);
596 p = builder.
create<math::FmaOp>(x2, p, alpha3);
597 p = builder.
create<math::FmaOp>(x2, p, alpha1);
598 p = builder.
create<arith::MulFOp>(x, p);
601 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
602 q = builder.
create<math::FmaOp>(x2, q, beta2);
603 q = builder.
create<math::FmaOp>(x2, q, beta0);
607 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
615 0.693147180559945309417232121458176568075500134360255254120680009493393621L
616 #define LOG2E_VALUE \
617 1.442695040888963407359924681001892137426645954152985934135449406931109219L
624 template <
typename Op>
636 template <
typename Op>
643 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
661 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
662 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
663 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
664 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
665 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
666 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
667 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
668 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
669 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
670 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
672 Value x = op.getOperand();
675 x =
max(builder, x, cstMinNormPos);
678 std::pair<Value, Value> pair =
frexp(builder, x,
true);
680 Value e = pair.second;
690 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
692 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
694 x = builder.
create<arith::SubFOp>(x, cstOne);
695 e = builder.
create<arith::SubFOp>(
696 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
697 x = builder.
create<arith::AddFOp>(x, tmp);
704 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
705 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
706 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
707 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
708 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
709 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
710 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
711 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
712 y0 = builder.
create<arith::MulFOp>(y0, x3);
714 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
715 x = builder.
create<arith::AddFOp>(x, y0);
719 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
722 x = builder.
create<math::FmaOp>(e, cstLn2, x);
725 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
726 op.getOperand(), cstZero);
727 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
728 op.getOperand(), cstZero);
729 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
730 op.getOperand(), cstPosInf);
736 Value aproximation = builder.
create<arith::SelectOp>(
737 zeroMask, cstMinusInf,
738 builder.
create<arith::SelectOp>(
740 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
748 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
749 using LogApproximationBase::LogApproximationBase;
751 LogicalResult matchAndRewrite(math::LogOp op,
753 return logMatchAndRewrite(op, rewriter,
false);
759 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
760 using LogApproximationBase::LogApproximationBase;
762 LogicalResult matchAndRewrite(math::Log2Op op,
764 return logMatchAndRewrite(op, rewriter,
true);
778 LogicalResult matchAndRewrite(math::Log1pOp op,
785 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
790 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
804 Value x = op.getOperand();
805 Value u = builder.
create<arith::AddFOp>(x, cstOne);
807 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
810 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
812 x, builder.
create<arith::DivFOp>(
813 logU, builder.
create<arith::SubFOp>(u, cstOne)));
814 Value approximation = builder.
create<arith::SelectOp>(
815 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
828 struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
832 LogicalResult matchAndRewrite(math::AsinOp op,
837 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
839 Value operand = op.getOperand();
842 if (!(elementType.
isF32() || elementType.
isF16()))
844 "only f32 and f16 type is supported.");
845 std::optional<VectorShape> shape =
vectorShape(operand);
853 return builder.
create<math::FmaOp>(a, b, c);
857 return builder.
create<arith::MulFOp>(a, b);
861 return builder.
create<arith::SubFOp>(a, b);
866 auto sqrt = [&](
Value a) ->
Value {
return builder.
create<math::SqrtOp>(a); };
869 return builder.
create<math::CopySignOp>(a, b);
873 return builder.
create<arith::SelectOp>(a, b, c);
877 Value aa = mul(operand, operand);
878 Value opp = sqrt(sub(bcast(
floatCst(builder, 1.0, elementType)), aa));
881 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
882 bcast(
floatCst(builder, 0.5, elementType)));
884 Value x = sel(gt, opp, abso);
889 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
890 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
892 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
893 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
894 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
895 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
896 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
897 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
898 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
899 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
900 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
901 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
903 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
907 Value rsub = sub(bcast(
floatCst(builder, 1.57079632679, elementType)), r);
908 r = sel(gt, rsub, r);
909 r = scopy(r, operand);
923 struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
927 LogicalResult matchAndRewrite(math::AcosOp op,
932 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
934 Value operand = op.getOperand();
937 if (!(elementType.
isF32() || elementType.
isF16()))
939 "only f32 and f16 type is supported.");
940 std::optional<VectorShape> shape =
vectorShape(operand);
948 return builder.
create<math::FmaOp>(a, b, c);
952 return builder.
create<arith::MulFOp>(a, b);
955 Value negOperand = builder.
create<arith::NegFOp>(operand);
960 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
961 Value r = builder.
create<arith::SelectOp>(selR, negOperand, operand);
962 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
964 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
967 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
968 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
969 builder.
create<math::AsinOp>(r));
971 Value falseVal = builder.
create<math::SqrtOp>(fma(half, r, half));
972 falseVal = builder.
create<math::AsinOp>(falseVal);
973 falseVal = mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
975 r = builder.
create<arith::SelectOp>(firstPred, trueVal, falseVal);
978 Value greaterThanNegOne =
979 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
982 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
984 Value betweenNegOneZero =
985 builder.
create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
987 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
988 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
989 builder.
create<arith::NegFOp>(r));
992 builder.
create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
1012 Value operand = op.getOperand();
1015 if (!(elementType.
isF32() || elementType.
isF16()))
1017 "only f32 and f16 type is supported.");
1018 std::optional<VectorShape> shape =
vectorShape(operand);
1022 return broadcast(builder, value, shape);
1025 const int intervalsCount = 3;
1026 const int polyDegree = 4;
1030 Value pp[intervalsCount][polyDegree + 1];
1031 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1032 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1033 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1034 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1035 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1036 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1037 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1038 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1039 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1040 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1041 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1042 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1043 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1044 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1045 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
1047 Value qq[intervalsCount][polyDegree + 1];
1048 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1049 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1050 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1051 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1052 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1053 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1054 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1055 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1056 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1057 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1058 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1059 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1060 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1061 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1062 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
1064 Value offsets[intervalsCount];
1065 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1066 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1067 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
1069 Value bounds[intervalsCount];
1070 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1071 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1072 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
1074 Value isNegativeArg =
1075 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1076 Value negArg = builder.
create<arith::NegFOp>(operand);
1077 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
1079 Value offset = offsets[0];
1080 Value p[polyDegree + 1];
1081 Value q[polyDegree + 1];
1082 for (
int i = 0; i <= polyDegree; ++i) {
1088 Value isLessThanBound[intervalsCount];
1089 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1090 isLessThanBound[
j] =
1091 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
1092 for (
int i = 0; i <= polyDegree; ++i) {
1093 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
1095 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
1098 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
1101 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
1102 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1104 Value pPoly = makePolynomialCalculation(builder, p, x);
1105 Value qPoly = makePolynomialCalculation(builder, q, x);
1106 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
1107 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
1108 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1112 Value negFormula = builder.
create<arith::NegFOp>(formula);
1114 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1128 const std::optional<VectorShape> shape,
Value value,
1129 float lowerBound,
float upperBound) {
1130 assert(!std::isnan(lowerBound));
1131 assert(!std::isnan(upperBound));
1134 return broadcast(builder, value, shape);
1137 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1138 return builder.
create<arith::SelectOp>(
1139 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
1145 value = selectCmp(arith::CmpFPredicate::UGE, value,
1146 bcast(
f32Cst(builder, lowerBound)));
1147 value = selectCmp(arith::CmpFPredicate::ULE, value,
1148 bcast(
f32Cst(builder, upperBound)));
1156 LogicalResult matchAndRewrite(math::ExpOp op,
1161 ExpApproximation::matchAndRewrite(math::ExpOp op,
1163 auto shape =
vectorShape(op.getOperand().getType());
1165 if (!elementTy.isF32())
1171 return builder.
create<arith::AddFOp>(a, b);
1174 return broadcast(builder, value, shape);
1178 return builder.
create<math::FmaOp>(a, b, c);
1181 return builder.
create<arith::MulFOp>(a, b);
1209 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1211 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1212 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1213 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1214 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1215 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1216 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1217 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1218 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1225 Value x = op.getOperand();
1226 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1227 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1268 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1271 x = fmla(cstExpC1, n, x);
1272 x = fmla(cstExpC2, n, x);
1275 Value z = fmla(x, cstExpP0, cstExpP1);
1276 z = fmla(z, x, cstExpP2);
1277 z = fmla(z, x, cstExpP3);
1278 z = fmla(z, x, cstExpP4);
1279 z = fmla(z, x, cstExpP5);
1280 z = fmla(z, mul(x, x), x);
1285 Value nI32 = builder.
create<arith::FPToSIOp>(i32Vec, n);
1291 Value ret = mul(z, pow2);
1294 return mlir::success();
1309 LogicalResult matchAndRewrite(math::ExpM1Op op,
1315 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1320 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1324 return broadcast(builder, value, shape);
1332 Value x = op.getOperand();
1335 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1336 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1337 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1338 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1344 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1348 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1349 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1350 Value approximation = builder.
create<arith::SelectOp>(
1352 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1363 template <
bool isSine,
typename OpTy>
1368 LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter)
const final;
1372 #define TWO_OVER_PI \
1373 0.6366197723675813430755350534900574481378385829618257949906693762L
1375 1.5707963267948966192313216916397514420985846996875529104874722961L
1380 template <
bool isSine,
typename OpTy>
1381 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1384 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1385 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1390 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1394 return broadcast(builder, value, shape);
1397 return builder.
create<arith::MulFOp>(a, b);
1400 return builder.
create<arith::SubFOp>(a, b);
1405 auto fPToSingedInteger = [&](
Value a) ->
Value {
1406 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1410 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1414 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1418 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1422 return builder.
create<arith::SelectOp>(cond, t, f);
1426 return builder.
create<math::FmaOp>(a, b, c);
1430 return builder.
create<arith::OrIOp>(a, b);
1436 Value x = op.getOperand();
1440 Value y = sub(x, mul(k, piOverTwo));
1443 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1445 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1446 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1447 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1449 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1451 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1454 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1455 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1456 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1458 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1460 Value kMod4 = modulo4(fPToSingedInteger(k));
1462 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1463 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1464 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1465 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1467 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1468 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1469 : bitwiseOr(kR1, kR2);
1471 Value y2 = mul(y, y);
1473 Value base = select(sinuseCos, cstOne, y);
1474 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1475 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1476 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1477 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1478 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1480 Value v1 = fmla(y2, cstC10, cstC8);
1481 Value v2 = fmla(y2, v1, cstC6);
1482 Value v3 = fmla(y2, v2, cstC4);
1483 Value v4 = fmla(y2, v3, cstC2);
1484 Value v5 = fmla(y2, v4, cstOne);
1485 Value v6 = mul(base, v5);
1487 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1502 LogicalResult matchAndRewrite(math::CbrtOp op,
1510 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1512 auto operand = op.getOperand();
1517 std::optional<VectorShape> shape =
vectorShape(operand);
1526 auto bconst = [&](TypedAttr attr) ->
Value {
1527 Value value = b.create<arith::ConstantOp>(attr);
1532 Value intTwo = bconst(b.getI32IntegerAttr(2));
1533 Value intFour = bconst(b.getI32IntegerAttr(4));
1534 Value intEight = bconst(b.getI32IntegerAttr(8));
1535 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1536 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1537 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1538 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1544 Value absValue = b.create<math::AbsFOp>(operand);
1545 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1546 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1547 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1548 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1551 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1552 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1555 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1556 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1559 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1563 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1564 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1565 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1566 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1567 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1568 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1571 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1572 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1573 divSquared = b.create<arith::DivFOp>(absValue, squared);
1574 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1575 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1579 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1580 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1581 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1595 LogicalResult matchAndRewrite(math::RsqrtOp op,
1601 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1606 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1609 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1614 return broadcast(builder, value, shape);
1618 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1622 Value negHalf = builder.
create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1627 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1628 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1629 op.getOperand(), cstPosInf);
1630 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1634 builder, op->getOperands(), 8, [&builder](
ValueRange operands) ->
Value {
1635 return builder.create<x86vector::RsqrtOp>(operands);
1642 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1643 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1644 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1652 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1675 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1676 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1677 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1678 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1679 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1680 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1684 .
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1685 LogApproximation, Log2Approximation, Log1pApproximation,
1687 AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1688 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1689 SinAndCosApproximation<false, math::CosOp>>(patterns.
getContext());
1691 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static std::optional< VectorShape > vectorShape(Type type)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Type broadcast(Type type, std::optional< VectorShape > shape)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the given op).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e. the max linear index).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.