34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
45 auto vectorType = dyn_cast<VectorType>(type);
59 assert(!isa<VectorType>(type) &&
"must be scalar type");
66 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
68 return !shape.empty() ? builder.
create<BroadcastOp>(type, value) : value;
94 assert(!operands.empty() &&
"operands must be not empty");
95 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
97 VectorType inputType = cast<VectorType>(operands[0].getType());
103 return compute(operands);
107 int64_t innerDim = inputShape.back();
108 int64_t expansionDim = innerDim / vectorWidth;
109 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
116 if (expansionDim > 1) {
118 expandedShape.insert(expandedShape.end() - 1, expansionDim);
119 expandedShape.back() = vectorWidth;
121 for (
unsigned i = 0; i < operands.size(); ++i) {
122 auto operand = operands[i];
123 auto eltType = cast<VectorType>(operand.getType()).getElementType();
125 expandedOperands[i] =
126 builder.
create<vector::ShapeCastOp>(expandedType, operand);
138 for (int64_t i = 0; i < maxIndex; ++i) {
143 extracted[tuple.index()] =
144 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
146 results[i] = compute(extracted);
150 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
153 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
155 for (int64_t i = 0; i < maxIndex; ++i)
156 result = builder.
create<vector::InsertOp>(results[i], result,
160 return builder.
create<vector::ShapeCastOp>(
170 assert((elementType.
isF16() || elementType.
isF32()) &&
171 "x must be f16 or f32 type.");
172 return builder.
create<arith::ConstantOp>(
185 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
195 return builder.
create<arith::SelectOp>(
196 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
202 return builder.
create<arith::SelectOp>(
203 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
210 return max(builder,
min(builder, value, upperBound), lowerBound);
216 bool isPositive =
false) {
233 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
234 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
235 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
238 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
239 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
240 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
243 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
244 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
245 builder.
create<arith::BitcastOp>(i32Vec, arg0),
246 bcast(
i32Cst(builder, 23)));
247 Value biasedExponent =
248 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
250 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
252 return {normalizedFraction, exponent};
266 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
268 auto bias = bcast(
i32Cst(builder, 127));
270 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
272 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
273 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
282 assert((elementType.
isF32() || elementType.
isF16()) &&
283 "x must be f32 or f16 type");
289 if (coeffs.size() == 1)
292 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
293 coeffs[coeffs.size() - 2]);
294 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
295 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
305 template <
typename T>
323 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
324 newType = shaped.clone(rewriter.
getF32Type());
325 }
else if (isa<FloatType>(origType)) {
329 "unable to find F32 equivalent type");
335 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
348 template <
typename T>
354 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
355 "requires same operands and result types");
356 return insertCasts<T>(op, rewriter);
376 AtanApproximation::matchAndRewrite(math::AtanOp op,
392 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
396 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
403 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
405 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
406 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
407 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
409 Value x = builder.
create<arith::DivFOp>(xnum, xden);
414 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
415 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
416 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
417 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
418 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
419 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
420 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
421 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
422 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
423 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
427 n = builder.
create<math::FmaOp>(xx, n, p1);
428 n = builder.
create<math::FmaOp>(xx, n, p2);
429 n = builder.
create<math::FmaOp>(xx, n, p3);
430 n = builder.
create<math::FmaOp>(xx, n, p4);
431 n = builder.
create<arith::MulFOp>(n, xx);
435 d = builder.
create<math::FmaOp>(xx, d, q1);
436 d = builder.
create<math::FmaOp>(xx, d, q2);
437 d = builder.
create<math::FmaOp>(xx, d, q3);
438 d = builder.
create<math::FmaOp>(xx, d, q4);
442 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
445 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
446 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
447 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
449 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
450 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
451 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
473 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
484 auto div = builder.
create<arith::DivFOp>(y, x);
485 auto atan = builder.
create<math::AtanOp>(div);
490 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
491 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
493 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
494 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
497 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
498 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
502 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
503 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
504 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
505 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
506 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
509 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
510 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
511 auto negativeHalfPiPi =
513 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
518 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
519 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
521 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
542 TanhApproximation::matchAndRewrite(math::TanhOp op,
555 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
556 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
566 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
567 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
568 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
569 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
570 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
571 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
572 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
575 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
576 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
577 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
578 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
584 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
585 p = builder.
create<math::FmaOp>(x2, p, alpha9);
586 p = builder.
create<math::FmaOp>(x2, p, alpha7);
587 p = builder.
create<math::FmaOp>(x2, p, alpha5);
588 p = builder.
create<math::FmaOp>(x2, p, alpha3);
589 p = builder.
create<math::FmaOp>(x2, p, alpha1);
590 p = builder.
create<arith::MulFOp>(x, p);
593 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
594 q = builder.
create<math::FmaOp>(x2, q, beta2);
595 q = builder.
create<math::FmaOp>(x2, q, beta0);
599 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
607 0.693147180559945309417232121458176568075500134360255254120680009493393621L
608 #define LOG2E_VALUE \
609 1.442695040888963407359924681001892137426645954152985934135449406931109219L
616 template <
typename Op>
628 template <
typename Op>
653 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
654 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
655 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
656 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
657 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
658 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
659 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
660 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
661 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
662 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
667 x =
max(builder, x, cstMinNormPos);
670 std::pair<Value, Value> pair =
frexp(builder, x,
true);
672 Value e = pair.second;
682 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
684 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
686 x = builder.
create<arith::SubFOp>(x, cstOne);
687 e = builder.
create<arith::SubFOp>(
688 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
689 x = builder.
create<arith::AddFOp>(x, tmp);
696 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
697 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
698 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
699 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
700 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
701 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
702 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
703 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
704 y0 = builder.
create<arith::MulFOp>(y0, x3);
706 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
707 x = builder.
create<arith::AddFOp>(x, y0);
711 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
714 x = builder.
create<math::FmaOp>(e, cstLn2, x);
717 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
719 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
721 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
728 Value aproximation = builder.
create<arith::SelectOp>(
729 zeroMask, cstMinusInf,
730 builder.
create<arith::SelectOp>(
732 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
740 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
741 using LogApproximationBase::LogApproximationBase;
745 return logMatchAndRewrite(op, rewriter,
false);
751 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
752 using LogApproximationBase::LogApproximationBase;
756 return logMatchAndRewrite(op, rewriter,
true);
777 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
797 Value u = builder.
create<arith::AddFOp>(x, cstOne);
799 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
802 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
804 x, builder.
create<arith::DivFOp>(
805 logU, builder.
create<arith::SubFOp>(u, cstOne)));
806 Value approximation = builder.
create<arith::SelectOp>(
807 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
829 if (!(elementType.
isF32() || elementType.
isF16()))
831 "only f32 and f16 type is supported.");
839 const int intervalsCount = 3;
840 const int polyDegree = 4;
844 Value pp[intervalsCount][polyDegree + 1];
845 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
846 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
847 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
848 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
849 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
850 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
851 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
852 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
853 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
854 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
855 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
856 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
857 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
858 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
859 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
861 Value qq[intervalsCount][polyDegree + 1];
862 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
863 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
864 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
865 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
866 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
867 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
868 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
869 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
870 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
871 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
872 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
873 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
874 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
875 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
876 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
878 Value offsets[intervalsCount];
879 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
880 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
881 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
883 Value bounds[intervalsCount];
884 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
885 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
886 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
888 Value isNegativeArg =
889 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
890 Value negArg = builder.
create<arith::NegFOp>(operand);
891 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
893 Value offset = offsets[0];
894 Value p[polyDegree + 1];
895 Value q[polyDegree + 1];
896 for (
int i = 0; i <= polyDegree; ++i) {
902 Value isLessThanBound[intervalsCount];
903 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
905 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
906 for (
int i = 0; i <= polyDegree; ++i) {
907 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
909 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
912 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
915 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
916 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
918 Value pPoly = makePolynomialCalculation(builder, p, x);
919 Value qPoly = makePolynomialCalculation(builder, q, x);
920 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
921 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
922 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
926 Value negFormula = builder.
create<arith::NegFOp>(formula);
928 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
943 float lowerBound,
float upperBound) {
944 assert(!std::isnan(lowerBound));
945 assert(!std::isnan(upperBound));
951 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
952 return builder.
create<arith::SelectOp>(
953 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
959 value = selectCmp(arith::CmpFPredicate::UGE, value,
960 bcast(
f32Cst(builder, lowerBound)));
961 value = selectCmp(arith::CmpFPredicate::ULE, value,
962 bcast(
f32Cst(builder, upperBound)));
975 ExpApproximation::matchAndRewrite(math::ExpOp op,
979 if (!elementTy.isF32())
985 return builder.
create<arith::AddFOp>(a, b);
992 return builder.
create<math::FmaOp>(a, b, c);
995 return builder.
create<arith::MulFOp>(a, b);
1023 Value cst_log2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1025 Value cst_exp_c1 = bcast(
f32Cst(builder, -0.693359375f));
1026 Value cst_exp_c2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1027 Value cst_exp_p0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1028 Value cst_exp_p1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1029 Value cst_exp_p2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1030 Value cst_exp_p3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1031 Value cst_exp_p4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1032 Value cst_exp_p5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1040 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1041 Value n =
floor(fmla(x, cst_log2ef, cst_half));
1082 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1085 x = fmla(cst_exp_c1, n, x);
1086 x = fmla(cst_exp_c2, n, x);
1089 Value z = fmla(x, cst_exp_p0, cst_exp_p1);
1090 z = fmla(z, x, cst_exp_p2);
1091 z = fmla(z, x, cst_exp_p3);
1092 z = fmla(z, x, cst_exp_p4);
1093 z = fmla(z, x, cst_exp_p5);
1094 z = fmla(z, mul(x, x), x);
1095 z = add(cst_one, z);
1099 Value n_i32 = builder.
create<arith::FPToSIOp>(i32_vec, n);
1105 Value ret = mul(z, pow2);
1129 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1138 return broadcast(builder, value, shape);
1149 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1150 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1151 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1152 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1158 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1162 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1163 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1164 Value approximation = builder.
create<arith::SelectOp>(
1166 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1177 template <
bool isSine,
typename OpTy>
1186 #define TWO_OVER_PI \
1187 0.6366197723675813430755350534900574481378385829618257949906693762L
1189 1.5707963267948966192313216916397514420985846996875529104874722961L
1194 template <
bool isSine,
typename OpTy>
1195 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1198 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1199 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1208 return broadcast(builder, value, shape);
1211 return builder.
create<arith::MulFOp>(a, b);
1214 return builder.
create<arith::SubFOp>(a, b);
1219 auto fPToSingedInteger = [&](
Value a) ->
Value {
1220 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1224 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1228 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1232 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1236 return builder.
create<arith::SelectOp>(cond, t, f);
1240 return builder.
create<math::FmaOp>(a, b, c);
1244 return builder.
create<arith::OrIOp>(a, b);
1254 Value y = sub(x, mul(k, piOverTwo));
1257 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1259 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1260 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1261 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1263 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1265 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1268 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1269 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1270 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1272 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1274 Value kMod4 = modulo4(fPToSingedInteger(k));
1276 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1277 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1278 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1279 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1281 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1282 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1283 : bitwiseOr(kR1, kR2);
1285 Value y2 = mul(y, y);
1287 Value base = select(sinuseCos, cstOne, y);
1288 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1289 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1290 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1291 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1292 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1294 Value v1 = fmla(y2, cstC10, cstC8);
1295 Value v2 = fmla(y2, v1, cstC6);
1296 Value v3 = fmla(y2, v2, cstC4);
1297 Value v4 = fmla(y2, v3, cstC2);
1298 Value v5 = fmla(y2, v4, cstOne);
1299 Value v6 = mul(base, v5);
1301 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1324 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1340 auto bconst = [&](TypedAttr attr) ->
Value {
1341 Value value = b.create<arith::ConstantOp>(attr);
1346 Value intTwo = bconst(b.getI32IntegerAttr(2));
1347 Value intFour = bconst(b.getI32IntegerAttr(4));
1348 Value intEight = bconst(b.getI32IntegerAttr(8));
1349 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1350 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1351 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1352 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1358 Value absValue = b.create<math::AbsFOp>(operand);
1359 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1360 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1361 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1362 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1365 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1366 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1369 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1370 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1373 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1377 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1378 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1379 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1380 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1381 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1382 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1385 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1386 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1387 divSquared = b.create<arith::DivFOp>(absValue, squared);
1388 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1389 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1393 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1394 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1395 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1415 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1423 if (shape.empty() || shape.back() % 8 != 0)
1428 return broadcast(builder, value, shape);
1432 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1441 arith::CmpFPredicate::OLT, op.
getOperand(), cstMinNormPos);
1442 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1444 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1449 return builder.create<x86vector::RsqrtOp>(operands);
1456 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1457 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1458 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1466 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1479 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1480 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1481 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1482 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1483 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1484 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1487 patterns.
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1488 LogApproximation, Log2Approximation, Log1pApproximation,
1490 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1491 SinAndCosApproximation<false, math::CosOp>>(
1494 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Type broadcast(Type type, ArrayRef< int64_t > shape)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static ArrayRef< int64_t > vectorShape(Type type)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
MPInt floor(const Fraction &f)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.