33 #include "llvm/ADT/ArrayRef.h" 34 #include "llvm/ADT/STLExtras.h" 57 assert(!type.
isa<VectorType>() &&
"must be scalar type");
58 return !shape.empty() ? VectorType::get(shape, type) : type;
64 assert(!value.
getType().
isa<VectorType>() &&
"must be scalar value");
66 return !shape.empty() ? builder.
create<BroadcastOp>(type,
value) : value;
92 assert(!operands.empty() &&
"operands must be not empty");
93 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
95 VectorType inputType = operands[0].
getType().cast<VectorType>();
100 if (inputShape == llvm::makeArrayRef(vectorWidth))
101 return compute(operands);
105 int64_t innerDim = inputShape.back();
106 int64_t expansionDim = innerDim / vectorWidth;
107 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
114 if (expansionDim > 1) {
116 expandedShape.insert(expandedShape.end() - 1, expansionDim);
117 expandedShape.back() = vectorWidth;
119 for (
unsigned i = 0; i < operands.size(); ++i) {
120 auto operand = operands[i];
122 auto expandedType = VectorType::get(expandedShape, eltType);
123 expandedOperands[i] =
124 builder.
create<vector::ShapeCastOp>(expandedType, operand);
138 for (int64_t i = 0; i < maxLinearIndex; ++i) {
143 extracted[tuple.index()] =
144 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
146 results[i] = compute(extracted);
151 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
153 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
155 for (int64_t i = 0; i < maxLinearIndex; ++i)
156 result = builder.
create<vector::InsertOp>(results[i], result,
160 return builder.
create<vector::ShapeCastOp>(
161 VectorType::get(inputShape, resultEltType), result);
177 Value i32Value =
i32Cst(builder, static_cast<int32_t>(bits));
187 return builder.
create<arith::SelectOp>(
188 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
value, bound),
194 return builder.
create<arith::SelectOp>(
195 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT,
value, bound),
202 return max(builder,
min(builder, value, upperBound), lowerBound);
208 bool isPositive =
false) {
225 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
226 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
227 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
230 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
231 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
232 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
235 Value arg0 = isPositive ? arg : builder.
create<math::AbsOp>(arg);
236 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
237 builder.
create<arith::BitcastOp>(i32Vec, arg0),
238 bcast(
i32Cst(builder, 23)));
239 Value biasedExponent =
240 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
242 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
244 return {normalizedFraction, exponent};
258 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
260 auto bias = bcast(
i32Cst(builder, 127));
262 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
264 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
265 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
279 if (coeffs.size() == 1)
282 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
283 coeffs[coeffs.size() - 2]);
284 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
285 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
295 template <
typename T>
313 if (
auto shaped = origType.
dyn_cast<ShapedType>()) {
314 newType = shaped.clone(rewriter.
getF32Type());
319 "unable to find F32 equivalent type");
325 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
326 auto result = rewriter.
create<math::Atan2Op>(loc, newType, operands);
337 template <
typename T>
343 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
344 "requires same operands and result types");
345 return insertCasts<T>(op, rewriter);
365 AtanApproximation::matchAndRewrite(math::AtanOp op,
367 auto operand = op.getOperand();
381 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
abs, reciprocal);
392 p = builder.
create<math::FmaOp>(x, p, n3);
393 p = builder.
create<math::FmaOp>(x, p, n4);
394 p = builder.
create<arith::MulFOp>(x, p);
397 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
398 Value sub = builder.
create<arith::SubFOp>(halfPi, p);
421 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
423 auto y = op.getOperand(0);
424 auto x = op.getOperand(1);
432 auto div = builder.
create<arith::DivFOp>(y, x);
433 auto atan = builder.
create<math::AtanOp>(div);
438 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
439 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
441 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
442 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
445 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
446 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
450 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
451 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
452 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
453 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
454 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
457 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
458 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
459 auto negativeHalfPiPi =
461 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
466 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
467 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
469 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
490 TanhApproximation::matchAndRewrite(math::TanhOp op,
503 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
504 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
505 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
510 arith::CmpFPredicate::OLT, builder.
create<math::AbsOp>(op.getOperand()),
514 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
515 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
516 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
517 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
518 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
519 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
520 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
523 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
524 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
525 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
526 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
532 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
533 p = builder.
create<math::FmaOp>(x2, p, alpha9);
534 p = builder.
create<math::FmaOp>(x2, p, alpha7);
535 p = builder.
create<math::FmaOp>(x2, p, alpha5);
536 p = builder.
create<math::FmaOp>(x2, p, alpha3);
537 p = builder.
create<math::FmaOp>(x2, p, alpha1);
538 p = builder.
create<arith::MulFOp>(x, p);
541 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
542 q = builder.
create<math::FmaOp>(x2, q, beta2);
543 q = builder.
create<math::FmaOp>(x2, q, beta0);
547 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
555 0.693147180559945309417232121458176568075500134360255254120680009493393621L 556 #define LOG2E_VALUE \ 557 1.442695040888963407359924681001892137426645954152985934135449406931109219L 564 template <
typename Op>
576 template <
typename Op>
601 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
602 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
603 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
604 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
605 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
606 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
607 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
608 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
609 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
610 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
612 Value x = op.getOperand();
615 x =
max(builder, x, cstMinNormPos);
618 std::pair<Value, Value> pair =
frexp(builder, x,
true);
620 Value e = pair.second;
630 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
632 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
634 x = builder.
create<arith::SubFOp>(x, cstOne);
635 e = builder.
create<arith::SubFOp>(
636 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
637 x = builder.
create<arith::AddFOp>(x, tmp);
644 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
645 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
646 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
647 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
648 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
649 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
650 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
651 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
652 y0 = builder.
create<arith::MulFOp>(y0, x3);
654 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
655 x = builder.
create<arith::AddFOp>(x, y0);
659 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
662 x = builder.
create<math::FmaOp>(e, cstLn2, x);
665 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
666 op.getOperand(), cstZero);
667 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
668 op.getOperand(), cstZero);
669 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
670 op.getOperand(), cstPosInf);
676 Value aproximation = builder.
create<arith::SelectOp>(
677 zeroMask, cstMinusInf,
678 builder.
create<arith::SelectOp>(
680 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
688 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
689 using LogApproximationBase::LogApproximationBase;
693 return logMatchAndRewrite(op, rewriter,
false);
699 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
700 using LogApproximationBase::LogApproximationBase;
704 return logMatchAndRewrite(op, rewriter,
true);
725 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
744 Value x = op.getOperand();
745 Value u = builder.
create<arith::AddFOp>(x, cstOne);
747 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
750 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
752 x, builder.
create<arith::DivFOp>(
753 logU, builder.
create<arith::SubFOp>(u, cstOne)));
754 Value approximation = builder.
create<arith::SelectOp>(
755 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
784 const int intervalsCount = 3;
785 const int polyDegree = 4;
789 Value pp[intervalsCount][polyDegree + 1];
790 pp[0][0] = bcast(
f32Cst(builder, +0.00000000000000000e+00f));
791 pp[0][1] = bcast(
f32Cst(builder, +1.12837916222975858e+00f));
792 pp[0][2] = bcast(
f32Cst(builder, -5.23018562988006470e-01f));
793 pp[0][3] = bcast(
f32Cst(builder, +2.09741709609267072e-01f));
794 pp[0][4] = bcast(
f32Cst(builder, +2.58146801602987875e-02f));
795 pp[1][0] = bcast(
f32Cst(builder, +0.00000000000000000e+00f));
796 pp[1][1] = bcast(
f32Cst(builder, +1.12750687816789140e+00f));
797 pp[1][2] = bcast(
f32Cst(builder, -3.64721408487825775e-01f));
798 pp[1][3] = bcast(
f32Cst(builder, +1.18407396425136952e-01f));
799 pp[1][4] = bcast(
f32Cst(builder, +3.70645533056476558e-02f));
800 pp[2][0] = bcast(
f32Cst(builder, -3.30093071049483172e-03f));
801 pp[2][1] = bcast(
f32Cst(builder, +3.51961938357697011e-03f));
802 pp[2][2] = bcast(
f32Cst(builder, -1.41373622814988039e-03f));
803 pp[2][3] = bcast(
f32Cst(builder, +2.53447094961941348e-04f));
804 pp[2][4] = bcast(
f32Cst(builder, -1.71048029455037401e-05f));
806 Value qq[intervalsCount][polyDegree + 1];
807 qq[0][0] = bcast(
f32Cst(builder, +1.000000000000000000e+00f));
808 qq[0][1] = bcast(
f32Cst(builder, -4.635138185962547255e-01f));
809 qq[0][2] = bcast(
f32Cst(builder, +5.192301327279782447e-01f));
810 qq[0][3] = bcast(
f32Cst(builder, -1.318089722204810087e-01f));
811 qq[0][4] = bcast(
f32Cst(builder, +7.397964654672315005e-02f));
812 qq[1][0] = bcast(
f32Cst(builder, +1.00000000000000000e+00f));
813 qq[1][1] = bcast(
f32Cst(builder, -3.27607011824493086e-01f));
814 qq[1][2] = bcast(
f32Cst(builder, +4.48369090658821977e-01f));
815 qq[1][3] = bcast(
f32Cst(builder, -8.83462621207857930e-02f));
816 qq[1][4] = bcast(
f32Cst(builder, +5.72442770283176093e-02f));
817 qq[2][0] = bcast(
f32Cst(builder, +1.00000000000000000e+00f));
818 qq[2][1] = bcast(
f32Cst(builder, -2.06069165953913769e+00f));
819 qq[2][2] = bcast(
f32Cst(builder, +1.62705939945477759e+00f));
820 qq[2][3] = bcast(
f32Cst(builder, -5.83389859211130017e-01f));
821 qq[2][4] = bcast(
f32Cst(builder, +8.21908939856640930e-02f));
823 Value offsets[intervalsCount];
824 offsets[0] = bcast(
f32Cst(builder, 0.0f));
825 offsets[1] = bcast(
f32Cst(builder, 0.0f));
826 offsets[2] = bcast(
f32Cst(builder, 1.0f));
828 Value bounds[intervalsCount];
829 bounds[0] = bcast(
f32Cst(builder, 0.8f));
830 bounds[1] = bcast(
f32Cst(builder, 2.0f));
831 bounds[2] = bcast(
f32Cst(builder, 3.75f));
833 Value isNegativeArg = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
834 op.getOperand(), zero);
835 Value negArg = builder.
create<arith::NegFOp>(op.getOperand());
837 builder.
create<arith::SelectOp>(isNegativeArg, negArg, op.getOperand());
839 Value offset = offsets[0];
840 Value p[polyDegree + 1];
841 Value q[polyDegree + 1];
842 for (
int i = 0; i <= polyDegree; ++i) {
848 Value isLessThanBound[intervalsCount];
849 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
851 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
852 for (
int i = 0; i <= polyDegree; ++i) {
853 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
855 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
858 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
861 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
862 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
864 Value pPoly = makePolynomialCalculation(builder, p, x);
865 Value qPoly = makePolynomialCalculation(builder, q, x);
866 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
867 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
868 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
872 Value negFormula = builder.
create<arith::NegFOp>(formula);
874 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
900 ExpApproximation::matchAndRewrite(math::ExpOp op,
915 return builder.
create<math::FmaOp>(a, b, c);
918 return builder.
create<arith::MulFOp>(a, b);
921 return builder.
create<arith::SubFOp>(a, b);
929 Value cstCephesExpP0 = bcast(
f32Cst(builder, 1.0));
930 Value cstCephesExpP1 = bcast(
f32Cst(builder, 1.0));
931 Value cstCephesExpP2 = bcast(
f32Cst(builder, 0.49970514590562437052f));
932 Value cstCephesExpP3 = bcast(
f32Cst(builder, 0.16873890085469545053f));
933 Value cstCephesExpP4 = bcast(
f32Cst(builder, 0.03668965196652099192f));
934 Value cstCephesExpP5 = bcast(
f32Cst(builder, 0.01314350012789660196f));
936 Value x = op.getOperand();
938 Value isNan = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
941 Value xL2Inv = mul(x, cstLog2E);
943 Value kLn2 = mul(kF32, cstLn2);
944 Value y = sub(x, kLn2);
948 Value y2 = mul(y, y);
949 Value y4 = mul(y2, y2);
951 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
952 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
953 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
954 Value expY = fmla(q1, y2, q0);
955 expY = fmla(q2, y4, expY);
960 Value k = builder.
create<arith::FPToSIOp>(i32Vec, kF32);
964 expY = mul(expY, exp2KValue);
973 auto constPosInfinity =
974 bcast(
f32Cst(builder, std::numeric_limits<float>::infinity()));
975 auto constNegIfinity =
976 bcast(
f32Cst(builder, -std::numeric_limits<float>::infinity()));
982 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
984 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
986 Value isNegInfinityX = builder.
create<arith::CmpFOp>(
987 arith::CmpFPredicate::OEQ, x, constNegIfinity);
988 Value isPosInfinityX = builder.
create<arith::CmpFOp>(
989 arith::CmpFPredicate::OEQ, x, constPosInfinity);
991 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
992 Value isComputable = builder.
create<arith::AndIOp>(rightBound, leftBound);
994 expY = builder.
create<arith::SelectOp>(
996 builder.
create<arith::SelectOp>(
997 isNegInfinityX, zerof32Const,
998 builder.
create<arith::SelectOp>(
999 isPosInfinityX, constPosInfinity,
1000 builder.
create<arith::SelectOp>(
1002 builder.
create<arith::SelectOp>(isPostiveX, constPosInfinity,
1026 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1043 Value x = op.getOperand();
1046 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1047 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1048 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1049 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1055 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1059 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1060 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1061 Value approximation = builder.
create<arith::SelectOp>(
1063 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1074 template <
bool isSine,
typename OpTy>
1083 #define TWO_OVER_PI \ 1084 0.6366197723675813430755350534900574481378385829618257949906693762L 1086 1.5707963267948966192313216916397514420985846996875529104874722961L 1091 template <
bool isSine,
typename OpTy>
1092 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1096 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1108 return builder.
create<arith::MulFOp>(a, b);
1111 return builder.
create<arith::SubFOp>(a, b);
1116 auto fPToSingedInteger = [&](
Value a) ->
Value {
1117 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1121 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1125 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1129 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1133 return builder.
create<arith::SelectOp>(cond, t, f);
1137 return builder.
create<math::FmaOp>(a, b, c);
1141 return builder.
create<arith::OrIOp>(a, b);
1147 Value x = op.getOperand();
1151 Value y = sub(x, mul(k, piOverTwo));
1154 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1156 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1157 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1158 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1160 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1162 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1165 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1166 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1167 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1169 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1171 Value kMod4 = modulo4(fPToSingedInteger(k));
1173 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1174 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1175 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1176 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1178 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1179 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1180 : bitwiseOr(kR1, kR2);
1182 Value y2 = mul(y, y);
1184 Value base = select(sinuseCos, cstOne, y);
1185 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1186 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1187 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1188 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1189 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1191 Value v1 = fmla(y2, cstC10, cstC8);
1192 Value v2 = fmla(y2, v1, cstC6);
1193 Value v3 = fmla(y2, v2, cstC4);
1194 Value v4 = fmla(y2, v3, cstC2);
1195 Value v5 = fmla(y2, v4, cstOne);
1196 Value v6 = mul(base, v5);
1198 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1219 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1227 if (shape.empty() || shape.back() % 8 != 0)
1236 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1240 Value negHalf = builder.
create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1245 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1246 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1247 op.getOperand(), cstPosInf);
1248 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1252 builder, op->getOperands(), 8, [&builder](
ValueRange operands) ->
Value {
1253 return builder.
create<x86vector::RsqrtOp>(operands);
1260 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1261 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1262 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1270 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1281 patterns.
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1282 LogApproximation, Log2Approximation, Log1pApproximation,
1284 ReuseF32Expansion<math::Atan2Op>,
1285 SinAndCosApproximation<true, math::SinOp>,
1286 SinAndCosApproximation<false, math::CosOp>>(
static Value f32Cst(ImplicitLocOpBuilder &builder, float value)
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
SmallVector< int64_t, 4 > computeStrides(ArrayRef< int64_t > shape, ArrayRef< int64_t > sizes)
Given the shape and sizes of a vector, returns the corresponding strides for each dimension...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation is a basic unit of execution within MLIR.
Attribute getZeroAttr(Type type)
operand_range getOperands()
Returns an iterator on the underlying Value's.
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
operand_type_range getOperandTypes()
int64_t floor(Fraction f)
static ArrayRef< int64_t > vectorShape(Type type)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
IntegerAttr getI32IntegerAttr(int32_t value)
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
IntegerType getIntegerType(unsigned width)
SmallVector< int64_t, 4 > delinearize(ArrayRef< int64_t > strides, int64_t linearIndex)
Given the strides together with a linear index in the dimension space, returns the vector-space offse...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Location getLoc()
The source location the operation was defined or derived from.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis, 0 if empty.
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static llvm::ManagedStatic< PassManagerOptions > options
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Type front()
Return first type in the range.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
Type getType() const
Return the type of this value.
Location getLoc()
The source location the operation was defined or derived from.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
This provides public APIs that all operations should have.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
static Type broadcast(Type type, ArrayRef< int64_t > shape)
SlowMPInt abs(const SlowMPInt &x)
Redeclarations of friend declarations above to make it discoverable by lookups.
int compare(Fraction x, Fraction y)
Three-way comparison between two fractions.
FloatAttr getF32FloatAttr(float value)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
This class provides an abstraction over the different types of ranges over Values.
result_type_range getResultTypes()
MLIRContext * getContext() const
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)