34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
47 bool empty()
const {
return sizes.empty(); }
53 auto vectorType = dyn_cast<VectorType>(type);
55 ?
VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
69 assert(!isa<VectorType>(type) &&
"must be scalar type");
78 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
80 return !shape.
empty() ? builder.
create<BroadcastOp>(type, value) : value;
106 assert(!operands.empty() &&
"operands must be not empty");
107 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
109 VectorType inputType = cast<VectorType>(operands[0].
getType());
115 return compute(operands);
119 int64_t innerDim = inputShape.back();
120 int64_t expansionDim = innerDim / vectorWidth;
121 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
128 if (expansionDim > 1) {
130 expandedShape.insert(expandedShape.end() - 1, expansionDim);
131 expandedShape.back() = vectorWidth;
133 for (
unsigned i = 0; i < operands.size(); ++i) {
134 auto operand = operands[i];
135 auto eltType = cast<VectorType>(operand.getType()).getElementType();
137 expandedOperands[i] =
138 builder.
create<vector::ShapeCastOp>(expandedType, operand);
150 for (int64_t i = 0; i < maxIndex; ++i) {
155 extracted[tuple.index()] =
156 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
158 results[i] = compute(extracted);
162 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
165 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
167 for (int64_t i = 0; i < maxIndex; ++i)
168 result = builder.
create<vector::InsertOp>(results[i], result,
172 return builder.
create<vector::ShapeCastOp>(
182 assert((elementType.
isF16() || elementType.
isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.
create<arith::ConstantOp>(
197 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
207 return builder.
create<arith::SelectOp>(
208 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
214 return builder.
create<arith::SelectOp>(
215 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
222 return max(builder,
min(builder, value, upperBound), lowerBound);
228 bool isPositive =
false) {
245 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
250 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
255 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
257 builder.
create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
262 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
284 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
304 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
317 template <
typename T>
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
360 template <
typename T>
364 LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter)
const final {
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
382 LogicalResult matchAndRewrite(math::AtanOp op,
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
404 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
408 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
418 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
421 Value x = builder.
create<arith::DivFOp>(xnum, xden);
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
439 n = builder.
create<math::FmaOp>(xx, n, p1);
440 n = builder.
create<math::FmaOp>(xx, n, p2);
441 n = builder.
create<math::FmaOp>(xx, n, p3);
442 n = builder.
create<math::FmaOp>(xx, n, p4);
443 n = builder.
create<arith::MulFOp>(n, xx);
447 d = builder.
create<math::FmaOp>(xx, d, q1);
448 d = builder.
create<math::FmaOp>(xx, d, q2);
449 d = builder.
create<math::FmaOp>(xx, d, q3);
450 d = builder.
create<math::FmaOp>(xx, d, q4);
454 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
479 LogicalResult matchAndRewrite(math::Atan2Op op,
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
496 auto div = builder.
create<arith::DivFOp>(y, x);
497 auto atan = builder.
create<math::AtanOp>(div);
502 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
505 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
509 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
514 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
517 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
518 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
521 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
525 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
530 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
533 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
548 LogicalResult matchAndRewrite(math::TanhOp op,
554 TanhApproximation::matchAndRewrite(math::TanhOp op,
567 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
568 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
578 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
579 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
580 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
581 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
582 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
583 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
584 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
587 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
588 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
589 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
590 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
596 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.
create<math::FmaOp>(x2, p, alpha9);
598 p = builder.
create<math::FmaOp>(x2, p, alpha7);
599 p = builder.
create<math::FmaOp>(x2, p, alpha5);
600 p = builder.
create<math::FmaOp>(x2, p, alpha3);
601 p = builder.
create<math::FmaOp>(x2, p, alpha1);
602 p = builder.
create<arith::MulFOp>(x, p);
605 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.
create<math::FmaOp>(x2, q, beta2);
607 q = builder.
create<math::FmaOp>(x2, q, beta0);
611 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
619 0.693147180559945309417232121458176568075500134360255254120680009493393621L
620 #define LOG2E_VALUE \
621 1.442695040888963407359924681001892137426645954152985934135449406931109219L
628 template <
typename Op>
640 template <
typename Op>
665 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
679 x =
max(builder, x, cstMinNormPos);
682 std::pair<Value, Value> pair =
frexp(builder, x,
true);
684 Value e = pair.second;
694 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
696 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
698 x = builder.
create<arith::SubFOp>(x, cstOne);
699 e = builder.
create<arith::SubFOp>(
700 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.
create<arith::AddFOp>(x, tmp);
708 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.
create<arith::MulFOp>(y0, x3);
718 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.
create<arith::AddFOp>(x, y0);
723 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
726 x = builder.
create<math::FmaOp>(e, cstLn2, x);
729 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
731 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
733 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
740 Value aproximation = builder.
create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.
create<arith::SelectOp>(
744 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
752 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
755 LogicalResult matchAndRewrite(math::LogOp op,
757 return logMatchAndRewrite(op, rewriter,
false);
763 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
766 LogicalResult matchAndRewrite(math::Log2Op op,
768 return logMatchAndRewrite(op, rewriter,
true);
782 LogicalResult matchAndRewrite(math::Log1pOp op,
789 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
809 Value u = builder.
create<arith::AddFOp>(x, cstOne);
811 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
814 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
816 x, builder.
create<arith::DivFOp>(
817 logU, builder.
create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.
create<arith::SelectOp>(
819 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
832 struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
836 LogicalResult matchAndRewrite(math::AsinOp op,
841 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846 if (!(elementType.
isF32() || elementType.
isF16()))
848 "only f32 and f16 type is supported.");
857 return builder.
create<math::FmaOp>(a, b, c);
861 return builder.
create<arith::MulFOp>(a, b);
864 Value s = mul(operand, operand);
866 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
867 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
869 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
870 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
871 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
872 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
873 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
874 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
875 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
876 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
877 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
878 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
880 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
882 r = fma(r, t, operand);
896 struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
900 LogicalResult matchAndRewrite(math::AcosOp op,
905 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
910 if (!(elementType.
isF32() || elementType.
isF16()))
912 "only f32 and f16 type is supported.");
921 return builder.
create<math::FmaOp>(a, b, c);
925 return builder.
create<arith::MulFOp>(a, b);
928 Value negOperand = builder.
create<arith::NegFOp>(operand);
933 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
934 Value r = builder.
create<arith::SelectOp>(selR, negOperand, operand);
935 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
937 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
940 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
941 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
942 builder.
create<math::AsinOp>(r));
944 Value falseVal = builder.
create<math::SqrtOp>(fma(half, r, half));
945 falseVal = builder.
create<math::AsinOp>(falseVal);
946 falseVal = mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
948 r = builder.
create<arith::SelectOp>(firstPred, trueVal, falseVal);
951 Value greaterThanNegOne =
952 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
955 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
957 Value betweenNegOneZero =
958 builder.
create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
960 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
961 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
962 builder.
create<arith::NegFOp>(r));
965 builder.
create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
988 if (!(elementType.
isF32() || elementType.
isF16()))
990 "only f32 and f16 type is supported.");
998 const int intervalsCount = 3;
999 const int polyDegree = 4;
1003 Value pp[intervalsCount][polyDegree + 1];
1004 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1005 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1006 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1007 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1008 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1009 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1010 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1011 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1012 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1013 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1014 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1015 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1016 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1017 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1018 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
1020 Value qq[intervalsCount][polyDegree + 1];
1021 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1022 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1023 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1024 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1025 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1026 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1027 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1028 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1029 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1030 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1031 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1032 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1033 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1034 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1035 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
1037 Value offsets[intervalsCount];
1038 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1039 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1040 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
1042 Value bounds[intervalsCount];
1043 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1044 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1045 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
1047 Value isNegativeArg =
1048 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1049 Value negArg = builder.
create<arith::NegFOp>(operand);
1050 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
1052 Value offset = offsets[0];
1053 Value p[polyDegree + 1];
1054 Value q[polyDegree + 1];
1055 for (
int i = 0; i <= polyDegree; ++i) {
1061 Value isLessThanBound[intervalsCount];
1062 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1063 isLessThanBound[
j] =
1064 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
1065 for (
int i = 0; i <= polyDegree; ++i) {
1066 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
1068 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
1071 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
1074 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
1075 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1077 Value pPoly = makePolynomialCalculation(builder, p, x);
1078 Value qPoly = makePolynomialCalculation(builder, q, x);
1079 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
1080 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
1081 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1085 Value negFormula = builder.
create<arith::NegFOp>(formula);
1087 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1101 Value value,
float lowerBound,
float upperBound) {
1102 assert(!std::isnan(lowerBound));
1103 assert(!std::isnan(upperBound));
1106 return broadcast(builder, value, shape);
1109 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1110 return builder.
create<arith::SelectOp>(
1111 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
1117 value = selectCmp(arith::CmpFPredicate::UGE, value,
1118 bcast(
f32Cst(builder, lowerBound)));
1119 value = selectCmp(arith::CmpFPredicate::ULE, value,
1120 bcast(
f32Cst(builder, upperBound)));
1128 LogicalResult matchAndRewrite(math::ExpOp op,
1133 ExpApproximation::matchAndRewrite(math::ExpOp op,
1137 if (!elementTy.isF32())
1143 return builder.
create<arith::AddFOp>(a, b);
1146 return broadcast(builder, value, shape);
1150 return builder.
create<math::FmaOp>(a, b, c);
1153 return builder.
create<arith::MulFOp>(a, b);
1181 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1183 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1184 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1185 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1186 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1187 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1188 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1189 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1190 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1198 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1199 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1240 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1243 x = fmla(cstExpC1, n, x);
1244 x = fmla(cstExpC2, n, x);
1247 Value z = fmla(x, cstExpP0, cstExpP1);
1248 z = fmla(z, x, cstExpP2);
1249 z = fmla(z, x, cstExpP3);
1250 z = fmla(z, x, cstExpP4);
1251 z = fmla(z, x, cstExpP5);
1252 z = fmla(z, mul(x, x), x);
1257 Value nI32 = builder.
create<arith::FPToSIOp>(i32Vec, n);
1263 Value ret = mul(z, pow2);
1266 return mlir::success();
1281 LogicalResult matchAndRewrite(math::ExpM1Op op,
1287 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1296 return broadcast(builder, value, shape);
1307 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1308 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1309 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1310 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1316 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1320 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1321 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1322 Value approximation = builder.
create<arith::SelectOp>(
1324 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1335 template <
bool isSine,
typename OpTy>
1340 LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter)
const final;
1344 #define TWO_OVER_PI \
1345 0.6366197723675813430755350534900574481378385829618257949906693762L
1347 1.5707963267948966192313216916397514420985846996875529104874722961L
1352 template <
bool isSine,
typename OpTy>
1353 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1356 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1357 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1366 return broadcast(builder, value, shape);
1369 return builder.
create<arith::MulFOp>(a, b);
1372 return builder.
create<arith::SubFOp>(a, b);
1377 auto fPToSingedInteger = [&](
Value a) ->
Value {
1378 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1382 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1386 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1390 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1394 return builder.
create<arith::SelectOp>(cond, t, f);
1398 return builder.
create<math::FmaOp>(a, b, c);
1402 return builder.
create<arith::OrIOp>(a, b);
1412 Value y = sub(x, mul(k, piOverTwo));
1415 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1417 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1418 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1419 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1421 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1423 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1426 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1427 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1428 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1430 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1432 Value kMod4 = modulo4(fPToSingedInteger(k));
1434 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1435 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1436 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1437 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1439 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1440 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1441 : bitwiseOr(kR1, kR2);
1443 Value y2 = mul(y, y);
1445 Value base = select(sinuseCos, cstOne, y);
1446 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1447 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1448 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1449 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1450 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1452 Value v1 = fmla(y2, cstC10, cstC8);
1453 Value v2 = fmla(y2, v1, cstC6);
1454 Value v3 = fmla(y2, v2, cstC4);
1455 Value v4 = fmla(y2, v3, cstC2);
1456 Value v5 = fmla(y2, v4, cstOne);
1457 Value v6 = mul(base, v5);
1459 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1474 LogicalResult matchAndRewrite(math::CbrtOp op,
1482 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1498 auto bconst = [&](TypedAttr attr) ->
Value {
1499 Value value = b.create<arith::ConstantOp>(attr);
1504 Value intTwo = bconst(b.getI32IntegerAttr(2));
1505 Value intFour = bconst(b.getI32IntegerAttr(4));
1506 Value intEight = bconst(b.getI32IntegerAttr(8));
1507 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1508 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1509 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1510 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1516 Value absValue = b.create<math::AbsFOp>(operand);
1517 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1518 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1519 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1520 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1523 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1524 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1527 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1528 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1531 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1535 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1536 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1537 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1538 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1539 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1540 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1543 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1544 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1545 divSquared = b.create<arith::DivFOp>(absValue, squared);
1546 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1547 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1551 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1552 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1553 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1567 LogicalResult matchAndRewrite(math::RsqrtOp op,
1573 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1581 if (shape.
empty() || shape.
sizes.back() % 8 != 0)
1586 return broadcast(builder, value, shape);
1590 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1599 arith::CmpFPredicate::OLT, op.
getOperand(), cstMinNormPos);
1600 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1602 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1607 return builder.create<x86vector::RsqrtOp>(operands);
1614 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1615 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1616 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1624 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1647 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1648 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1649 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1650 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1651 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1652 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1656 .
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1657 LogApproximation, Log2Approximation, Log1pApproximation,
1659 AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1660 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1661 SinAndCosApproximation<false, math::CosOp>>(patterns.
getContext());
1663 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static Type broadcast(Type type, VectorShape shape)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static VectorShape vectorShape(Type type)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.