34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
// Returns true when the `sizes` member holds no elements. Elsewhere in this
// file the result is used to decide whether a value needs a BroadcastOp at all
// and to bail out of shape-dependent rewrites before `sizes.back()` is read.
47 bool empty()
const {
return sizes.empty(); }
53 auto vectorType = dyn_cast<VectorType>(type);
55 ?
VectorShape{vectorType.getShape(), vectorType.getScalableDims()}
69 assert(!isa<VectorType>(type) &&
"must be scalar type");
78 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
80 return !shape.
empty() ? builder.
create<BroadcastOp>(type, value) : value;
106 assert(!operands.empty() &&
"operands must be not empty");
107 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
109 VectorType inputType = cast<VectorType>(operands[0].
getType());
115 return compute(operands);
119 int64_t innerDim = inputShape.back();
120 int64_t expansionDim = innerDim / vectorWidth;
121 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
128 if (expansionDim > 1) {
130 expandedShape.insert(expandedShape.end() - 1, expansionDim);
131 expandedShape.back() = vectorWidth;
133 for (
unsigned i = 0; i < operands.size(); ++i) {
134 auto operand = operands[i];
135 auto eltType = cast<VectorType>(operand.getType()).getElementType();
137 expandedOperands[i] =
138 builder.
create<vector::ShapeCastOp>(expandedType, operand);
150 for (int64_t i = 0; i < maxIndex; ++i) {
155 extracted[tuple.index()] =
156 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
158 results[i] = compute(extracted);
162 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
165 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
167 for (int64_t i = 0; i < maxIndex; ++i)
168 result = builder.
create<vector::InsertOp>(results[i], result,
172 return builder.
create<vector::ShapeCastOp>(
182 assert((elementType.
isF16() || elementType.
isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.
create<arith::ConstantOp>(
197 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
207 return builder.
create<arith::SelectOp>(
208 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
214 return builder.
create<arith::SelectOp>(
215 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
222 return max(builder,
min(builder, value, upperBound), lowerBound);
228 bool isPositive =
false) {
245 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
250 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
255 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
257 builder.
create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
262 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
284 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
304 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
317 template <
typename T>
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
360 template <
typename T>
364 LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter)
const final {
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
382 LogicalResult matchAndRewrite(math::AtanOp op,
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
404 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
408 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
418 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
421 Value x = builder.
create<arith::DivFOp>(xnum, xden);
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
439 n = builder.
create<math::FmaOp>(xx, n, p1);
440 n = builder.
create<math::FmaOp>(xx, n, p2);
441 n = builder.
create<math::FmaOp>(xx, n, p3);
442 n = builder.
create<math::FmaOp>(xx, n, p4);
443 n = builder.
create<arith::MulFOp>(n, xx);
447 d = builder.
create<math::FmaOp>(xx, d, q1);
448 d = builder.
create<math::FmaOp>(xx, d, q2);
449 d = builder.
create<math::FmaOp>(xx, d, q3);
450 d = builder.
create<math::FmaOp>(xx, d, q4);
454 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
479 LogicalResult matchAndRewrite(math::Atan2Op op,
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
496 auto div = builder.
create<arith::DivFOp>(y, x);
497 auto atan = builder.
create<math::AtanOp>(div);
502 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
505 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
509 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
514 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
517 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
518 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
521 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
525 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
530 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
533 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
548 LogicalResult matchAndRewrite(math::TanhOp op,
554 TanhApproximation::matchAndRewrite(math::TanhOp op,
567 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
568 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
578 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
579 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
580 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
581 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
582 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
583 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
584 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
587 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
588 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
589 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
590 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
596 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.
create<math::FmaOp>(x2, p, alpha9);
598 p = builder.
create<math::FmaOp>(x2, p, alpha7);
599 p = builder.
create<math::FmaOp>(x2, p, alpha5);
600 p = builder.
create<math::FmaOp>(x2, p, alpha3);
601 p = builder.
create<math::FmaOp>(x2, p, alpha1);
602 p = builder.
create<arith::MulFOp>(x, p);
605 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.
create<math::FmaOp>(x2, q, beta2);
607 q = builder.
create<math::FmaOp>(x2, q, beta0);
611 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
619 0.693147180559945309417232121458176568075500134360255254120680009493393621L
// log2(e), i.e. 1 / ln(2), approximately 1.4427. The literal carries an `L`
// suffix, so it is a long double constant written with more digits than the
// type can represent.
620 #define LOG2E_VALUE \
621 1.442695040888963407359924681001892137426645954152985934135449406931109219L
628 template <
typename Op>
640 template <
typename Op>
665 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
679 x =
max(builder, x, cstMinNormPos);
682 std::pair<Value, Value> pair =
frexp(builder, x,
true);
684 Value e = pair.second;
694 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
696 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
698 x = builder.
create<arith::SubFOp>(x, cstOne);
699 e = builder.
create<arith::SubFOp>(
700 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.
create<arith::AddFOp>(x, tmp);
708 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.
create<arith::MulFOp>(y0, x3);
718 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.
create<arith::AddFOp>(x, y0);
723 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
726 x = builder.
create<math::FmaOp>(e, cstLn2, x);
729 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
731 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
733 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
740 Value aproximation = builder.
create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.
create<arith::SelectOp>(
744 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
752 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
755 LogicalResult matchAndRewrite(math::LogOp op,
757 return logMatchAndRewrite(op, rewriter,
false);
763 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
766 LogicalResult matchAndRewrite(math::Log2Op op,
768 return logMatchAndRewrite(op, rewriter,
true);
782 LogicalResult matchAndRewrite(math::Log1pOp op,
789 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
809 Value u = builder.
create<arith::AddFOp>(x, cstOne);
811 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
814 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
816 x, builder.
create<arith::DivFOp>(
817 logU, builder.
create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.
create<arith::SelectOp>(
819 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
832 struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
836 LogicalResult matchAndRewrite(math::AsinOp op,
841 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846 if (!(elementType.
isF32() || elementType.
isF16()))
848 "only f32 and f16 type is supported.");
857 return builder.
create<math::FmaOp>(a, b, c);
861 return builder.
create<arith::MulFOp>(a, b);
865 return builder.
create<arith::SubFOp>(a, b);
// Local helper: creates a math::SqrtOp on `a` through the by-reference
// captured `builder` and hands the result back as a Value.
870 auto sqrt = [&](
Value a) ->
Value {
return builder.
create<math::SqrtOp>(a); };
873 return builder.
create<math::CopySignOp>(a, b);
877 return builder.
create<arith::SelectOp>(a, b, c);
881 Value aa = mul(operand, operand);
882 Value opp = sqrt(sub(bcast(
floatCst(builder, 1.0, elementType)), aa));
885 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
886 bcast(
floatCst(builder, 0.5, elementType)));
888 Value x = sel(gt, opp, abso);
893 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
894 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
896 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
897 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
898 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
899 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
900 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
901 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
902 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
903 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
904 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
905 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
907 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
911 Value rsub = sub(bcast(
floatCst(builder, 1.57079632679, elementType)), r);
912 r = sel(gt, rsub, r);
913 r = scopy(r, operand);
927 struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
931 LogicalResult matchAndRewrite(math::AcosOp op,
936 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941 if (!(elementType.
isF32() || elementType.
isF16()))
943 "only f32 and f16 type is supported.");
952 return builder.
create<math::FmaOp>(a, b, c);
956 return builder.
create<arith::MulFOp>(a, b);
959 Value negOperand = builder.
create<arith::NegFOp>(operand);
964 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
965 Value r = builder.
create<arith::SelectOp>(selR, negOperand, operand);
966 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
968 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
971 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
972 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
973 builder.
create<math::AsinOp>(r));
975 Value falseVal = builder.
create<math::SqrtOp>(fma(half, r, half));
976 falseVal = builder.
create<math::AsinOp>(falseVal);
977 falseVal = mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
979 r = builder.
create<arith::SelectOp>(firstPred, trueVal, falseVal);
982 Value greaterThanNegOne =
983 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
986 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
988 Value betweenNegOneZero =
989 builder.
create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
991 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
992 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
993 builder.
create<arith::NegFOp>(r));
996 builder.
create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
1019 if (!(elementType.
isF32() || elementType.
isF16()))
1021 "only f32 and f16 type is supported.");
1026 return broadcast(builder, value, shape);
1029 const int intervalsCount = 3;
1030 const int polyDegree = 4;
1034 Value pp[intervalsCount][polyDegree + 1];
1035 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1036 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1037 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1038 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1039 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1040 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1041 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1042 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1043 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1044 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1045 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1046 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1047 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1048 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1049 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
1051 Value qq[intervalsCount][polyDegree + 1];
1052 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1053 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1054 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1055 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1056 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1057 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1058 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1059 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1060 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1061 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1062 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1063 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1064 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1065 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1066 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
1068 Value offsets[intervalsCount];
1069 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1070 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1071 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
1073 Value bounds[intervalsCount];
1074 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1075 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1076 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
1078 Value isNegativeArg =
1079 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1080 Value negArg = builder.
create<arith::NegFOp>(operand);
1081 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
1083 Value offset = offsets[0];
1084 Value p[polyDegree + 1];
1085 Value q[polyDegree + 1];
1086 for (
int i = 0; i <= polyDegree; ++i) {
1092 Value isLessThanBound[intervalsCount];
1093 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1094 isLessThanBound[
j] =
1095 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
1096 for (
int i = 0; i <= polyDegree; ++i) {
1097 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
1099 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
1102 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
1105 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
1106 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1108 Value pPoly = makePolynomialCalculation(builder, p, x);
1109 Value qPoly = makePolynomialCalculation(builder, q, x);
1110 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
1111 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
1112 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1116 Value negFormula = builder.
create<arith::NegFOp>(formula);
1118 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1132 Value value,
float lowerBound,
float upperBound) {
1133 assert(!std::isnan(lowerBound));
1134 assert(!std::isnan(upperBound));
1137 return broadcast(builder, value, shape);
1140 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1141 return builder.
create<arith::SelectOp>(
1142 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
1148 value = selectCmp(arith::CmpFPredicate::UGE, value,
1149 bcast(
f32Cst(builder, lowerBound)));
1150 value = selectCmp(arith::CmpFPredicate::ULE, value,
1151 bcast(
f32Cst(builder, upperBound)));
1159 LogicalResult matchAndRewrite(math::ExpOp op,
1164 ExpApproximation::matchAndRewrite(math::ExpOp op,
1168 if (!elementTy.isF32())
1174 return builder.
create<arith::AddFOp>(a, b);
1177 return broadcast(builder, value, shape);
1181 return builder.
create<math::FmaOp>(a, b, c);
1184 return builder.
create<arith::MulFOp>(a, b);
1212 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1214 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1215 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1216 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1217 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1218 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1219 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1220 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1221 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1229 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1230 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1271 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1274 x = fmla(cstExpC1, n, x);
1275 x = fmla(cstExpC2, n, x);
1278 Value z = fmla(x, cstExpP0, cstExpP1);
1279 z = fmla(z, x, cstExpP2);
1280 z = fmla(z, x, cstExpP3);
1281 z = fmla(z, x, cstExpP4);
1282 z = fmla(z, x, cstExpP5);
1283 z = fmla(z, mul(x, x), x);
1288 Value nI32 = builder.
create<arith::FPToSIOp>(i32Vec, n);
1294 Value ret = mul(z, pow2);
1297 return mlir::success();
1312 LogicalResult matchAndRewrite(math::ExpM1Op op,
1318 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1327 return broadcast(builder, value, shape);
1338 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1339 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1340 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1341 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1347 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1351 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1352 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1353 Value approximation = builder.
create<arith::SelectOp>(
1355 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1366 template <
bool isSine,
typename OpTy>
1371 LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter)
const final;
// 2 / pi, approximately 0.6366. The literal carries an `L` suffix, so it is a
// long double constant written with more digits than the type can represent.
1375 #define TWO_OVER_PI \
1376 0.6366197723675813430755350534900574481378385829618257949906693762L
1378 1.5707963267948966192313216916397514420985846996875529104874722961L
1383 template <
bool isSine,
typename OpTy>
1384 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1387 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1388 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1397 return broadcast(builder, value, shape);
1400 return builder.
create<arith::MulFOp>(a, b);
1403 return builder.
create<arith::SubFOp>(a, b);
1408 auto fPToSingedInteger = [&](
Value a) ->
Value {
1409 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1413 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1417 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1421 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1425 return builder.
create<arith::SelectOp>(cond, t, f);
1429 return builder.
create<math::FmaOp>(a, b, c);
1433 return builder.
create<arith::OrIOp>(a, b);
1443 Value y = sub(x, mul(k, piOverTwo));
1446 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1448 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1449 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1450 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1452 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1454 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1457 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1458 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1459 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1461 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1463 Value kMod4 = modulo4(fPToSingedInteger(k));
1465 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1466 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1467 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1468 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1470 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1471 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1472 : bitwiseOr(kR1, kR2);
1474 Value y2 = mul(y, y);
1476 Value base = select(sinuseCos, cstOne, y);
1477 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1478 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1479 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1480 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1481 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1483 Value v1 = fmla(y2, cstC10, cstC8);
1484 Value v2 = fmla(y2, v1, cstC6);
1485 Value v3 = fmla(y2, v2, cstC4);
1486 Value v4 = fmla(y2, v3, cstC2);
1487 Value v5 = fmla(y2, v4, cstOne);
1488 Value v6 = mul(base, v5);
1490 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1505 LogicalResult matchAndRewrite(math::CbrtOp op,
1513 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1529 auto bconst = [&](TypedAttr attr) ->
Value {
1530 Value value = b.create<arith::ConstantOp>(attr);
1535 Value intTwo = bconst(b.getI32IntegerAttr(2));
1536 Value intFour = bconst(b.getI32IntegerAttr(4));
1537 Value intEight = bconst(b.getI32IntegerAttr(8));
1538 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1539 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1540 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1541 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1547 Value absValue = b.create<math::AbsFOp>(operand);
1548 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1549 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1550 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1551 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1554 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1555 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1558 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1559 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1562 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1566 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1567 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1568 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1569 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1570 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1571 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1574 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1575 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1576 divSquared = b.create<arith::DivFOp>(absValue, squared);
1577 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1578 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1582 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1583 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
1584 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1598 LogicalResult matchAndRewrite(math::RsqrtOp op,
1604 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1612 if (shape.
empty() || shape.
sizes.back() % 8 != 0)
1617 return broadcast(builder, value, shape);
1621 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1630 arith::CmpFPredicate::OLT, op.
getOperand(), cstMinNormPos);
1631 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1633 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1638 return builder.create<x86vector::RsqrtOp>(operands);
1645 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1646 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1647 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1655 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1678 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1679 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1680 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1681 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1682 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1683 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1687 .
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1688 LogApproximation, Log2Approximation, Log1pApproximation,
1690 AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1691 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1692 SinAndCosApproximation<false, math::CosOp>>(patterns.
getContext());
1694 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static Type broadcast(Type type, VectorShape shape)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static VectorShape vectorShape(Type type)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the given op).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offsets in each dimension for a de-linearized index.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e. the max linear index).
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.