33 #include "llvm/ADT/ArrayRef.h"
34 #include "llvm/ADT/STLExtras.h"
43 auto vectorType = type.
dyn_cast<VectorType>();
57 assert(!type.
isa<VectorType>() &&
"must be scalar type");
58 return !shape.empty() ? VectorType::get(shape, type) : type;
64 assert(!value.
getType().
isa<VectorType>() &&
"must be scalar value");
66 return !shape.empty() ? builder.
create<BroadcastOp>(type, value) : value;
92 assert(!operands.empty() &&
"operands must be not empty");
93 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
95 VectorType inputType = operands[0].
getType().cast<VectorType>();
101 return compute(operands);
105 int64_t innerDim = inputShape.back();
106 int64_t expansionDim = innerDim / vectorWidth;
107 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
114 if (expansionDim > 1) {
116 expandedShape.insert(expandedShape.end() - 1, expansionDim);
117 expandedShape.back() = vectorWidth;
119 for (
unsigned i = 0; i < operands.size(); ++i) {
120 auto operand = operands[i];
122 auto expandedType = VectorType::get(expandedShape, eltType);
123 expandedOperands[i] =
124 builder.
create<vector::ShapeCastOp>(expandedType, operand);
136 for (int64_t i = 0; i < maxIndex; ++i) {
141 extracted[tuple.index()] =
142 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
144 results[i] = compute(extracted);
149 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
151 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
153 for (int64_t i = 0; i < maxIndex; ++i)
154 result = builder.
create<vector::InsertOp>(results[i], result,
158 return builder.
create<vector::ShapeCastOp>(
159 VectorType::get(inputShape, resultEltType), result);
168 assert((elementType.
isF16() || elementType.
isF32()) &&
169 "x must be f16 or f32 type.");
170 return builder.
create<arith::ConstantOp>(
183 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
193 return builder.
create<arith::SelectOp>(
194 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
200 return builder.
create<arith::SelectOp>(
201 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
208 return max(builder,
min(builder, value, upperBound), lowerBound);
214 bool isPositive =
false) {
231 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
232 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
233 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
236 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
237 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
238 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
241 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
242 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
243 builder.
create<arith::BitcastOp>(i32Vec, arg0),
244 bcast(
i32Cst(builder, 23)));
245 Value biasedExponent =
246 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
248 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
250 return {normalizedFraction, exponent};
264 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
266 auto bias = bcast(
i32Cst(builder, 127));
268 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
270 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
271 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
280 assert((elementType.
isF32() || elementType.
isF16()) &&
281 "x must be f32 or f16 type");
287 if (coeffs.size() == 1)
290 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
291 coeffs[coeffs.size() - 2]);
292 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
293 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
303 template <
typename T>
321 if (
auto shaped = origType.
dyn_cast<ShapedType>()) {
322 newType = shaped.clone(rewriter.
getF32Type());
327 "unable to find F32 equivalent type");
333 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
346 template <
typename T>
352 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
353 "requires same operands and result types");
354 return insertCasts<T>(op, rewriter);
374 AtanApproximation::matchAndRewrite(math::AtanOp op,
376 auto operand = op.getOperand();
390 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
abs, reciprocal);
401 p = builder.
create<math::FmaOp>(x, p, n3);
402 p = builder.
create<math::FmaOp>(x, p, n4);
403 p = builder.
create<arith::MulFOp>(x, p);
406 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
407 Value sub = builder.
create<arith::SubFOp>(halfPi, p);
430 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
432 auto y = op.getOperand(0);
433 auto x = op.getOperand(1);
441 auto div = builder.
create<arith::DivFOp>(y, x);
442 auto atan = builder.
create<math::AtanOp>(div);
447 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
448 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
450 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
451 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
454 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
455 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
459 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
460 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
461 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
462 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
463 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
466 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
467 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
468 auto negativeHalfPiPi =
470 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
475 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
476 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
478 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
499 TanhApproximation::matchAndRewrite(math::TanhOp op,
512 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
513 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
514 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
519 arith::CmpFPredicate::OLT, builder.
create<math::AbsFOp>(op.getOperand()),
523 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
524 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
525 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
526 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
527 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
528 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
529 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
532 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
533 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
534 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
535 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
541 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
542 p = builder.
create<math::FmaOp>(x2, p, alpha9);
543 p = builder.
create<math::FmaOp>(x2, p, alpha7);
544 p = builder.
create<math::FmaOp>(x2, p, alpha5);
545 p = builder.
create<math::FmaOp>(x2, p, alpha3);
546 p = builder.
create<math::FmaOp>(x2, p, alpha1);
547 p = builder.
create<arith::MulFOp>(x, p);
550 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
551 q = builder.
create<math::FmaOp>(x2, q, beta2);
552 q = builder.
create<math::FmaOp>(x2, q, beta0);
556 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
564 0.693147180559945309417232121458176568075500134360255254120680009493393621L
565 #define LOG2E_VALUE \
566 1.442695040888963407359924681001892137426645954152985934135449406931109219L
573 template <
typename Op>
585 template <
typename Op>
610 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
611 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
612 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
613 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
614 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
615 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
616 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
617 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
618 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
619 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
621 Value x = op.getOperand();
624 x =
max(builder, x, cstMinNormPos);
627 std::pair<Value, Value> pair =
frexp(builder, x,
true);
629 Value e = pair.second;
639 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
641 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
643 x = builder.
create<arith::SubFOp>(x, cstOne);
644 e = builder.
create<arith::SubFOp>(
645 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
646 x = builder.
create<arith::AddFOp>(x, tmp);
653 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
654 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
655 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
656 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
657 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
658 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
659 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
660 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
661 y0 = builder.
create<arith::MulFOp>(y0, x3);
663 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
664 x = builder.
create<arith::AddFOp>(x, y0);
668 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
671 x = builder.
create<math::FmaOp>(e, cstLn2, x);
674 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
675 op.getOperand(), cstZero);
676 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
677 op.getOperand(), cstZero);
678 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
679 op.getOperand(), cstPosInf);
685 Value aproximation = builder.
create<arith::SelectOp>(
686 zeroMask, cstMinusInf,
687 builder.
create<arith::SelectOp>(
689 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
697 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
698 using LogApproximationBase::LogApproximationBase;
702 return logMatchAndRewrite(op, rewriter,
false);
708 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
709 using LogApproximationBase::LogApproximationBase;
713 return logMatchAndRewrite(op, rewriter,
true);
734 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
753 Value x = op.getOperand();
754 Value u = builder.
create<arith::AddFOp>(x, cstOne);
756 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
759 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
761 x, builder.
create<arith::DivFOp>(
762 logU, builder.
create<arith::SubFOp>(u, cstOne)));
763 Value approximation = builder.
create<arith::SelectOp>(
764 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
783 Value operand = op.getOperand();
786 if (!(elementType.
isF32() || elementType.
isF16()))
788 "only f32 and f16 type is supported.");
796 const int intervalsCount = 3;
797 const int polyDegree = 4;
801 Value pp[intervalsCount][polyDegree + 1];
802 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
803 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
804 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
805 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
806 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
807 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
808 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
809 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
810 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
811 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
812 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
813 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
814 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
815 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
816 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
818 Value qq[intervalsCount][polyDegree + 1];
819 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
820 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
821 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
822 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
823 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
824 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
825 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
826 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
827 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
828 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
829 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
830 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
831 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
832 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
833 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
835 Value offsets[intervalsCount];
836 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
837 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
838 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
840 Value bounds[intervalsCount];
841 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
842 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
843 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
845 Value isNegativeArg =
846 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
847 Value negArg = builder.
create<arith::NegFOp>(operand);
848 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
850 Value offset = offsets[0];
851 Value p[polyDegree + 1];
852 Value q[polyDegree + 1];
853 for (
int i = 0; i <= polyDegree; ++i) {
859 Value isLessThanBound[intervalsCount];
860 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
862 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
863 for (
int i = 0; i <= polyDegree; ++i) {
864 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
866 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
869 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
872 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
873 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
875 Value pPoly = makePolynomialCalculation(builder, p, x);
876 Value qPoly = makePolynomialCalculation(builder, q, x);
877 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
878 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
879 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
883 Value negFormula = builder.
create<arith::NegFOp>(formula);
885 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
911 ExpApproximation::matchAndRewrite(math::ExpOp op,
926 return builder.
create<math::FmaOp>(a, b, c);
929 return builder.
create<arith::MulFOp>(a, b);
932 return builder.
create<arith::SubFOp>(a, b);
940 Value cstCephesExpP0 = bcast(
f32Cst(builder, 1.0));
941 Value cstCephesExpP1 = bcast(
f32Cst(builder, 1.0));
942 Value cstCephesExpP2 = bcast(
f32Cst(builder, 0.49970514590562437052f));
943 Value cstCephesExpP3 = bcast(
f32Cst(builder, 0.16873890085469545053f));
944 Value cstCephesExpP4 = bcast(
f32Cst(builder, 0.03668965196652099192f));
945 Value cstCephesExpP5 = bcast(
f32Cst(builder, 0.01314350012789660196f));
947 Value x = op.getOperand();
949 Value isNan = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x);
952 Value xL2Inv = mul(x, cstLog2E);
954 Value kLn2 = mul(kF32, cstLn2);
955 Value y = sub(x, kLn2);
959 Value y2 = mul(y, y);
960 Value y4 = mul(y2, y2);
962 Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0);
963 Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2);
964 Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4);
965 Value expY = fmla(q1, y2, q0);
966 expY = fmla(q2, y4, expY);
975 expY = mul(expY, exp2KValue);
984 auto constPosInfinity =
985 bcast(
f32Cst(builder, std::numeric_limits<float>::infinity()));
986 auto constNegIfinity =
987 bcast(
f32Cst(builder, -std::numeric_limits<float>::infinity()));
993 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sle, k, kMaxConst);
995 builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sge, k, kMaxNegConst);
997 Value isNegInfinityX = builder.
create<arith::CmpFOp>(
998 arith::CmpFPredicate::OEQ, x, constNegIfinity);
999 Value isPosInfinityX = builder.
create<arith::CmpFOp>(
1000 arith::CmpFPredicate::OEQ, x, constPosInfinity);
1002 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zerof32Const);
1003 Value isComputable = builder.
create<arith::AndIOp>(rightBound, leftBound);
1005 expY = builder.
create<arith::SelectOp>(
1007 builder.
create<arith::SelectOp>(
1008 isNegInfinityX, zerof32Const,
1009 builder.
create<arith::SelectOp>(
1010 isPosInfinityX, constPosInfinity,
1011 builder.
create<arith::SelectOp>(
1013 builder.
create<arith::SelectOp>(isPostiveX, constPosInfinity,
1037 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1046 return broadcast(builder, value, shape);
1054 Value x = op.getOperand();
1057 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1058 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1059 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1060 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1066 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1070 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1071 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1072 Value approximation = builder.
create<arith::SelectOp>(
1074 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1085 template <
bool isSine,
typename OpTy>
1094 #define TWO_OVER_PI \
1095 0.6366197723675813430755350534900574481378385829618257949906693762L
1097 1.5707963267948966192313216916397514420985846996875529104874722961L
1102 template <
bool isSine,
typename OpTy>
1103 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1106 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1107 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1116 return broadcast(builder, value, shape);
1119 return builder.
create<arith::MulFOp>(a, b);
1122 return builder.
create<arith::SubFOp>(a, b);
1127 auto fPToSingedInteger = [&](
Value a) ->
Value {
1128 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1132 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1136 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1140 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1144 return builder.
create<arith::SelectOp>(cond, t, f);
1148 return builder.
create<math::FmaOp>(a, b, c);
1152 return builder.
create<arith::OrIOp>(a, b);
1158 Value x = op.getOperand();
1162 Value y = sub(x, mul(k, piOverTwo));
1165 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1167 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1168 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1169 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1171 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1173 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1176 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1177 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1178 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1180 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1182 Value kMod4 = modulo4(fPToSingedInteger(k));
1184 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1185 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1186 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1187 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1189 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1190 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1191 : bitwiseOr(kR1, kR2);
1193 Value y2 = mul(y, y);
1195 Value base = select(sinuseCos, cstOne, y);
1196 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1197 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1198 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1199 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1200 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1202 Value v1 = fmla(y2, cstC10, cstC8);
1203 Value v2 = fmla(y2, v1, cstC6);
1204 Value v3 = fmla(y2, v2, cstC4);
1205 Value v4 = fmla(y2, v3, cstC2);
1206 Value v5 = fmla(y2, v4, cstOne);
1207 Value v6 = mul(base, v5);
1209 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1232 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1234 auto operand = op.getOperand();
1249 Value value = b.create<arith::ConstantOp>(attr);
1254 Value intTwo = bconst(b.getI32IntegerAttr(2));
1255 Value intFour = bconst(b.getI32IntegerAttr(4));
1256 Value intEight = bconst(b.getI32IntegerAttr(8));
1257 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1258 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1259 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1260 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1266 Value absValue = b.create<math::AbsFOp>(operand);
1267 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
1268 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1269 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1270 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
1273 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1274 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
1277 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1278 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
1281 intValue = b.create<arith::AddIOp>(intValue, intMagic);
1285 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1286 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1287 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1288 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1289 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1290 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1293 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1294 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1295 divSquared = b.create<arith::DivFOp>(absValue, squared);
1296 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1297 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
1301 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
1302 floatValue = b.create<arith::SelectOp>(
isZero, fpZero, floatValue);
1303 floatValue = b.create<math::CopySignOp>(floatValue, operand);
1323 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1331 if (shape.empty() || shape.back() % 8 != 0)
1336 return broadcast(builder, value, shape);
1340 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1344 Value negHalf = builder.
create<arith::MulFOp>(op.getOperand(), cstNegHalf);
1349 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
1350 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1351 op.getOperand(), cstPosInf);
1352 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
1356 builder, op->getOperands(), 8, [&builder](
ValueRange operands) ->
Value {
1357 return builder.create<x86vector::RsqrtOp>(operands);
1364 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1365 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1366 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
1374 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
1387 .
add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1388 ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1389 ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1390 ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1391 ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1392 ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1395 patterns.
add<AtanApproximation, Atan2Approximation, TanhApproximation,
1396 LogApproximation, Log2Approximation, Log1pApproximation,
1398 CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
1399 SinAndCosApproximation<false, math::CosOp>>(
1402 patterns.
add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Type broadcast(Type type, ArrayRef< int64_t > shape)
static Value f32Cst(ImplicitLocOpBuilder &builder, float value)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static ArrayRef< int64_t > vectorShape(Type type)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static bool isZero(OpFoldResult v)
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
Attribute getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
int compare(const Fraction &x, const Fraction &y)
Three-way comparison between two fractions.
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
MPInt floor(const Fraction &f)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options={})
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.