34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
47 bool empty()
const {
return sizes.empty(); }
53 auto vectorType = dyn_cast<VectorType>(type);
55 ?
VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
69 assert(!isa<VectorType>(type) &&
"must be scalar type");
78 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
80 return !shape.
empty() ? builder.
create<BroadcastOp>(type, value) : value;
106 assert(!operands.empty() &&
"operands must be not empty");
107 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
109 VectorType inputType = cast<VectorType>(operands[0].getType());
115 return compute(operands);
119 int64_t innerDim = inputShape.back();
120 int64_t expansionDim = innerDim / vectorWidth;
121 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
128 if (expansionDim > 1) {
130 expandedShape.insert(expandedShape.end() - 1, expansionDim);
131 expandedShape.back() = vectorWidth;
133 for (
unsigned i = 0; i < operands.size(); ++i) {
134 auto operand = operands[i];
135 auto eltType = cast<VectorType>(operand.getType()).getElementType();
137 expandedOperands[i] =
138 builder.
create<vector::ShapeCastOp>(expandedType, operand);
150 for (int64_t i = 0; i < maxIndex; ++i) {
155 extracted[tuple.index()] =
156 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
158 results[i] = compute(extracted);
162 Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
165 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
167 for (int64_t i = 0; i < maxIndex; ++i)
168 result = builder.
create<vector::InsertOp>(results[i], result,
172 return builder.
create<vector::ShapeCastOp>(
182 assert((elementType.
isF16() || elementType.
isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.
create<arith::ConstantOp>(
197 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
207 return builder.
create<arith::SelectOp>(
208 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
214 return builder.
create<arith::SelectOp>(
215 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
222 return max(builder,
min(builder, value, upperBound), lowerBound);
228 bool isPositive =
false) {
245 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
250 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
255 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
257 builder.
create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
262 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
284 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
304 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
317 template <
typename T>
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
360 template <
typename T>
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
404 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
408 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
418 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
421 Value x = builder.
create<arith::DivFOp>(xnum, xden);
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
439 n = builder.
create<math::FmaOp>(xx, n, p1);
440 n = builder.
create<math::FmaOp>(xx, n, p2);
441 n = builder.
create<math::FmaOp>(xx, n, p3);
442 n = builder.
create<math::FmaOp>(xx, n, p4);
443 n = builder.
create<arith::MulFOp>(n, xx);
447 d = builder.
create<math::FmaOp>(xx, d, q1);
448 d = builder.
create<math::FmaOp>(xx, d, q2);
449 d = builder.
create<math::FmaOp>(xx, d, q3);
450 d = builder.
create<math::FmaOp>(xx, d, q4);
454 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
496 auto div = builder.
create<arith::DivFOp>(y, x);
497 auto atan = builder.
create<math::AtanOp>(div);
502 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
505 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
509 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
514 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
517 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
518 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
521 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
525 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
530 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
533 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
554 TanhApproximation::matchAndRewrite(math::TanhOp op,
567 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
568 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
578 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
579 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
580 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
581 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
582 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
583 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
584 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
587 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
588 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
589 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
590 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
596 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.
create<math::FmaOp>(x2, p, alpha9);
598 p = builder.
create<math::FmaOp>(x2, p, alpha7);
599 p = builder.
create<math::FmaOp>(x2, p, alpha5);
600 p = builder.
create<math::FmaOp>(x2, p, alpha3);
601 p = builder.
create<math::FmaOp>(x2, p, alpha1);
602 p = builder.
create<arith::MulFOp>(x, p);
605 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.
create<math::FmaOp>(x2, q, beta2);
607 q = builder.
create<math::FmaOp>(x2, q, beta0);
611 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
619 0.693147180559945309417232121458176568075500134360255254120680009493393621L
620 #define LOG2E_VALUE \
621 1.442695040888963407359924681001892137426645954152985934135449406931109219L
628 template <
typename Op>
640 template <
typename Op>
665 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
679 x =
max(builder, x, cstMinNormPos);
682 std::pair<Value, Value> pair =
frexp(builder, x,
true);
684 Value e = pair.second;
694 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
696 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
698 x = builder.
create<arith::SubFOp>(x, cstOne);
699 e = builder.
create<arith::SubFOp>(
700 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.
create<arith::AddFOp>(x, tmp);
708 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.
create<arith::MulFOp>(y0, x3);
718 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.
create<arith::AddFOp>(x, y0);
723 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
726 x = builder.
create<math::FmaOp>(e, cstLn2, x);
729 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
731 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
733 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
740 Value aproximation = builder.
create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.
create<arith::SelectOp>(
744 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
752 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
757 return logMatchAndRewrite(op, rewriter,
false);
763 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
768 return logMatchAndRewrite(op, rewriter,
true);
789 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
809 Value u = builder.
create<arith::AddFOp>(x, cstOne);
811 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
814 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
816 x, builder.
create<arith::DivFOp>(
817 logU, builder.
create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.
create<arith::SelectOp>(
819 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
841 if (!(elementType.
isF32() || elementType.
isF16()))
843 "only f32 and f16 type is supported.");
851 const int intervalsCount = 3;
852 const int polyDegree = 4;
856 Value pp[intervalsCount][polyDegree + 1];
857 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
858 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
859 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
860 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
861 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
862 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
863 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
864 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
865 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
866 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
867 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
868 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
869 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
870 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
871 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
873 Value qq[intervalsCount][polyDegree + 1];
874 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
875 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
876 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
877 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
878 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
879 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
880 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
881 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
882 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
883 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
884 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
885 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
886 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
887 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
888 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
890 Value offsets[intervalsCount];
891 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
892 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
893 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
895 Value bounds[intervalsCount];
896 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
897 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
898 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
900 Value isNegativeArg =
901 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
902 Value negArg = builder.
create<arith::NegFOp>(operand);
903 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
905 Value offset = offsets[0];
906 Value p[polyDegree + 1];
907 Value q[polyDegree + 1];
908 for (
int i = 0; i <= polyDegree; ++i) {
914 Value isLessThanBound[intervalsCount];
915 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
917 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
918 for (
int i = 0; i <= polyDegree; ++i) {
919 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
921 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
924 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
927 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
928 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
930 Value pPoly = makePolynomialCalculation(builder, p, x);
931 Value qPoly = makePolynomialCalculation(builder, q, x);
932 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
933 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
934 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
938 Value negFormula = builder.
create<arith::NegFOp>(formula);
940 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
954 Value value,
float lowerBound,
float upperBound) {
955 assert(!std::isnan(lowerBound));
956 assert(!std::isnan(upperBound));
962 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
963 return builder.
create<arith::SelectOp>(
964 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
970 value = selectCmp(arith::CmpFPredicate::UGE, value,
971 bcast(
f32Cst(builder, lowerBound)));
972 value = selectCmp(arith::CmpFPredicate::ULE, value,
973 bcast(
f32Cst(builder, upperBound)));
986 ExpApproximation::matchAndRewrite(math::ExpOp op,
990 if (!elementTy.isF32())
996 return builder.
create<arith::AddFOp>(a, b);
1003 return builder.
create<math::FmaOp>(a, b, c);
1006 return builder.
create<arith::MulFOp>(a, b);
1034 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1036 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1037 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1038 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1039 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1040 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1041 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1042 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1043 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1051 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1052 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1093 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1096 x = fmla(cstExpC1, n, x);
1097 x = fmla(cstExpC2, n, x);
1100 Value z = fmla(x, cstExpP0, cstExpP1);
1101 z = fmla(z, x, cstExpP2);
1102 z = fmla(z, x, cstExpP3);
1103 z = fmla(z, x, cstExpP4);
1104 z = fmla(z, x, cstExpP5);
1105 z = fmla(z, mul(x, x), x);
1110 Value nI32 = builder.
create<arith::FPToSIOp>(i32Vec, n);
1116 Value ret = mul(z, pow2);
1140 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1149 return broadcast(builder, value, shape);
1160 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1161 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1162 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1163 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1169 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1173 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1174 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1175 Value approximation = builder.
create<arith::SelectOp>(
1177 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1188 template <
bool isSine,
typename OpTy>
1197 #define TWO_OVER_PI \
1198 0.6366197723675813430755350534900574481378385829618257949906693762L
1200 1.5707963267948966192313216916397514420985846996875529104874722961L
1205 template <
bool isSine,
typename OpTy>
1206 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1209 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1210 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1219 return broadcast(builder, value, shape);
1222 return builder.
create<arith::MulFOp>(a, b);
1225 return builder.
create<arith::SubFOp>(a, b);
1230 auto fPToSingedInteger = [&](
Value a) ->
Value {
1231 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1235 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1239 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1243 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1247 return builder.
create<arith::SelectOp>(cond, t, f);
1251 return builder.
create<math::FmaOp>(a, b, c);
1255 return builder.
create<arith::OrIOp>(a, b);
1265 Value y = sub(x, mul(k, piOverTwo));
1268 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1270 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1271 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1272 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1274 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1276 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1279 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1280 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1281 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1283 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1285 Value kMod4 = modulo4(fPToSingedInteger(k));
1287 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1288 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1289 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1290 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1292 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1293 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1294 : bitwiseOr(kR1, kR2);
1296 Value y2 = mul(y, y);
1298 Value base = select(sinuseCos, cstOne, y);
1299 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1300 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1301 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1302 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1303 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1305 Value v1 = fmla(y2, cstC10, cstC8);
1306 Value v2 = fmla(y2, v1, cstC6);
1307 Value v3 = fmla(y2, v2, cstC4);
1308 Value v4 = fmla(y2, v3, cstC2);
1309 Value v5 = fmla(y2, v4, cstOne);
1310 Value v6 = mul(base, v5);
1312 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1335 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1351 auto bconst = [&](TypedAttr attr) ->
Value {
1352 Value value = b.create<arith::ConstantOp>(attr);
1357 Value intTwo = bconst(b.getI32IntegerAttr(2));
1358 Value intFour = bconst(b.getI32IntegerAttr(4));
1359 Value intEight = bconst(b.getI32IntegerAttr(8));
1360 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1361 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1362 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1363 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1369 Value absValue = b.create<math::AbsFOp>(operand);
1370 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1371 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1372 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1373 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1376 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1377 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1380 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1381 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1384 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1388 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1389 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1390 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1391 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1392 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1393 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1396 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1397 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1398 divSquared = b.create<arith::DivFOp>(absValue, squared);
1399 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1400 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1404 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1405 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1406 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1426 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1434 if (shape.
empty() || shape.
sizes.back() % 8 != 0)
1439 return broadcast(builder, value, shape);
1443 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1452 arith::CmpFPredicate::OLT, op.
getOperand(), cstMinNormPos);
1453 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1455 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1460 return builder.create<x86vector::RsqrtOp>(operands);
1467 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1468 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1469 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1477 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1500 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1501 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1502 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1503 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1504 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1505 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1508 patterns.
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1509 LogApproximation, Log2Approximation, Log1pApproximation,
1511 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1512 SinAndCosApproximation<false, math::CosOp>>(
1515 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static Type broadcast(Type type, VectorShape shape)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static VectorShape vectorShape(Type type)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Fraction abs(const Fraction &f)
MPInt floor(const Fraction &f)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.