34 #include "llvm/ADT/ArrayRef.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/Support/MathExtras.h"
50 if (
auto vectorType = dyn_cast<VectorType>(type)) {
51 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
66 assert(!isa<VectorType>(type) &&
"must be scalar type");
67 return shape ?
VectorType::get(shape->sizes, type, shape->scalableFlags)
73 std::optional<VectorShape> shape) {
74 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
76 return shape ? builder.
create<BroadcastOp>(type, value) : value;
102 assert(!operands.empty() &&
"operands must be not empty");
103 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
105 VectorType inputType = cast<VectorType>(operands[0].
getType());
111 return compute(operands);
115 int64_t innerDim = inputShape.back();
116 int64_t expansionDim = innerDim / vectorWidth;
117 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
124 if (expansionDim > 1) {
126 expandedShape.insert(expandedShape.end() - 1, expansionDim);
127 expandedShape.back() = vectorWidth;
129 for (
unsigned i = 0; i < operands.size(); ++i) {
130 auto operand = operands[i];
131 auto eltType = cast<VectorType>(operand.getType()).getElementType();
133 expandedOperands[i] =
134 builder.
create<vector::ShapeCastOp>(expandedType, operand);
146 for (int64_t i = 0; i < maxIndex; ++i) {
151 extracted[tuple.index()] =
152 builder.
create<vector::ExtractOp>(tuple.value(), offsets);
154 results[i] = compute(extracted);
158 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
161 resultExpandedType, builder.
getZeroAttr(resultExpandedType));
163 for (int64_t i = 0; i < maxIndex; ++i)
164 result = builder.
create<vector::InsertOp>(results[i], result,
168 return builder.
create<vector::ShapeCastOp>(
182 assert((elementType.
isF16() || elementType.
isF32()) &&
183 "x must be f16 or f32 type.");
184 return builder.
create<arith::ConstantOp>(
197 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
207 return builder.
create<arith::SelectOp>(
208 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
214 return builder.
create<arith::SelectOp>(
215 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
222 return max(builder,
min(builder, value, upperBound), lowerBound);
228 bool isPositive =
false) {
230 std::optional<VectorShape> shape =
vectorShape(arg);
245 Value i32Half = builder.
create<arith::BitcastOp>(i32, cstHalf);
246 Value i32InvMantMask = builder.
create<arith::BitcastOp>(i32, cstInvMantMask);
247 Value i32Arg = builder.
create<arith::BitcastOp>(i32Vec, arg);
250 Value tmp0 = builder.
create<arith::AndIOp>(i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = builder.
create<arith::OrIOp>(tmp0, bcast(i32Half));
252 Value normalizedFraction = builder.
create<arith::BitcastOp>(f32Vec, tmp1);
255 Value arg0 = isPositive ? arg : builder.
create<math::AbsFOp>(arg);
256 Value biasedExponentBits = builder.
create<arith::ShRUIOp>(
257 builder.
create<arith::BitcastOp>(i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 builder.
create<arith::SIToFPOp>(f32Vec, biasedExponentBits);
262 builder.
create<arith::SubFOp>(biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
270 std::optional<VectorShape> shape =
vectorShape(arg);
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = builder.
create<arith::AddIOp>(arg, bias);
284 builder.
create<arith::ShLIOp>(biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = builder.
create<arith::BitcastOp>(f32Vec, exp2ValueInt);
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
304 Value res = builder.
create<math::FmaOp>(x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = builder.
create<math::FmaOp>(x, res, coeffs[i]);
317 template <
typename T>
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(rewriter.
create<arith::ExtFOp>(loc, newType, operand));
360 template <
typename T>
364 LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter)
const final {
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
382 LogicalResult matchAndRewrite(math::AtanOp op,
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
390 auto operand = op.getOperand();
394 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
404 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, twoThirds);
408 Value xden = builder.
create<arith::SelectOp>(cmp2, addone, one);
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
abs, tan3pio8);
418 xnum = builder.
create<arith::SelectOp>(cmp1, one, xnum);
419 xden = builder.
create<arith::SelectOp>(cmp1,
abs, xden);
421 Value x = builder.
create<arith::DivFOp>(xnum, xden);
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
439 n = builder.
create<math::FmaOp>(xx, n, p1);
440 n = builder.
create<math::FmaOp>(xx, n, p2);
441 n = builder.
create<math::FmaOp>(xx, n, p3);
442 n = builder.
create<math::FmaOp>(xx, n, p4);
443 n = builder.
create<arith::MulFOp>(n, xx);
447 d = builder.
create<math::FmaOp>(xx, d, q1);
448 d = builder.
create<math::FmaOp>(xx, d, q2);
449 d = builder.
create<math::FmaOp>(xx, d, q3);
450 d = builder.
create<math::FmaOp>(xx, d, q4);
454 ans0 = builder.
create<math::FmaOp>(ans0, x, x);
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = builder.
create<arith::AddFOp>(mpi4, ans0);
459 Value ans = builder.
create<arith::SelectOp>(cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = builder.
create<arith::SubFOp>(mpi2, ans0);
463 ans = builder.
create<arith::SelectOp>(cmp1, ans1, ans);
479 LogicalResult matchAndRewrite(math::Atan2Op op,
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
493 std::optional<VectorShape> shape =
vectorShape(op.getResult());
496 auto div = builder.
create<arith::DivFOp>(y, x);
497 auto atan = builder.
create<math::AtanOp>(div);
502 auto addPi = builder.
create<arith::AddFOp>(atan, pi);
503 auto subPi = builder.
create<arith::SubFOp>(atan, pi);
505 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = builder.
create<arith::SelectOp>(atanGt, subPi, addPi);
509 auto xGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
510 Value result = builder.
create<arith::SelectOp>(xGt, atan, flippedAtan);
514 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
515 Value yGt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
516 Value isHalfPi = builder.
create<arith::AndIOp>(xZero, yGt);
517 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
518 result = builder.
create<arith::SelectOp>(isHalfPi, halfPi, result);
521 Value yLt = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
522 Value isNegativeHalfPiPi = builder.
create<arith::AndIOp>(xZero, yLt);
523 auto negativeHalfPiPi =
525 result = builder.
create<arith::SelectOp>(isNegativeHalfPiPi, negativeHalfPiPi,
530 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
531 Value isNan = builder.
create<arith::AndIOp>(xZero, yZero);
533 result = builder.
create<arith::SelectOp>(isNan, cstNan, result);
548 LogicalResult matchAndRewrite(math::TanhOp op,
554 TanhApproximation::matchAndRewrite(math::TanhOp op,
559 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
567 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
568 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
569 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
574 arith::CmpFPredicate::OLT, builder.
create<math::AbsFOp>(op.getOperand()),
578 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
579 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
580 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
581 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
582 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
583 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
584 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
587 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
588 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
589 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
590 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
596 Value p = builder.
create<math::FmaOp>(x2, alpha13, alpha11);
597 p = builder.
create<math::FmaOp>(x2, p, alpha9);
598 p = builder.
create<math::FmaOp>(x2, p, alpha7);
599 p = builder.
create<math::FmaOp>(x2, p, alpha5);
600 p = builder.
create<math::FmaOp>(x2, p, alpha3);
601 p = builder.
create<math::FmaOp>(x2, p, alpha1);
602 p = builder.
create<arith::MulFOp>(x, p);
605 Value q = builder.
create<math::FmaOp>(x2, beta6, beta4);
606 q = builder.
create<math::FmaOp>(x2, q, beta2);
607 q = builder.
create<math::FmaOp>(x2, q, beta0);
611 tinyMask, x, builder.
create<arith::DivFOp>(p, q));
619 0.693147180559945309417232121458176568075500134360255254120680009493393621L
// log2(e) (equivalently 1 / ln 2), spelled out as a long double literal (note
// the trailing `L`) so the constant keeps full precision until it is narrowed
// at the point of use.
620 #define LOG2E_VALUE \
621 1.442695040888963407359924681001892137426645954152985934135449406931109219L
628 template <
typename Op>
640 template <
typename Op>
647 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
665 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
666 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
667 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
668 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
669 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
670 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
671 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
672 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
673 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
674 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
676 Value x = op.getOperand();
679 x =
max(builder, x, cstMinNormPos);
682 std::pair<Value, Value> pair =
frexp(builder, x,
true);
684 Value e = pair.second;
694 Value mask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x,
696 Value tmp = builder.
create<arith::SelectOp>(mask, x, cstZero);
698 x = builder.
create<arith::SubFOp>(x, cstOne);
699 e = builder.
create<arith::SubFOp>(
700 e, builder.
create<arith::SelectOp>(mask, cstOne, cstZero));
701 x = builder.
create<arith::AddFOp>(x, tmp);
708 y0 = builder.
create<math::FmaOp>(cstCephesLogP0, x, cstCephesLogP1);
709 y1 = builder.
create<math::FmaOp>(cstCephesLogP3, x, cstCephesLogP4);
710 y2 = builder.
create<math::FmaOp>(cstCephesLogP6, x, cstCephesLogP7);
711 y0 = builder.
create<math::FmaOp>(y0, x, cstCephesLogP2);
712 y1 = builder.
create<math::FmaOp>(y1, x, cstCephesLogP5);
713 y2 = builder.
create<math::FmaOp>(y2, x, cstCephesLogP8);
714 y0 = builder.
create<math::FmaOp>(y0, x3, y1);
715 y0 = builder.
create<math::FmaOp>(y0, x3, y2);
716 y0 = builder.
create<arith::MulFOp>(y0, x3);
718 y0 = builder.
create<math::FmaOp>(cstNegHalf, x2, y0);
719 x = builder.
create<arith::AddFOp>(x, y0);
723 x = builder.
create<math::FmaOp>(x, cstLog2e, e);
726 x = builder.
create<math::FmaOp>(e, cstLn2, x);
729 Value invalidMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
730 op.getOperand(), cstZero);
731 Value zeroMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
732 op.getOperand(), cstZero);
733 Value posInfMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstPosInf);
740 Value aproximation = builder.
create<arith::SelectOp>(
741 zeroMask, cstMinusInf,
742 builder.
create<arith::SelectOp>(
744 builder.
create<arith::SelectOp>(posInfMask, cstPosInf, x)));
752 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
753 using LogApproximationBase::LogApproximationBase;
755 LogicalResult matchAndRewrite(math::LogOp op,
757 return logMatchAndRewrite(op, rewriter,
false);
763 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
764 using LogApproximationBase::LogApproximationBase;
766 LogicalResult matchAndRewrite(math::Log2Op op,
768 return logMatchAndRewrite(op, rewriter,
true);
782 LogicalResult matchAndRewrite(math::Log1pOp op,
789 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
794 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
808 Value x = op.getOperand();
809 Value u = builder.
create<arith::AddFOp>(x, cstOne);
811 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, cstOne);
814 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, u, logU);
816 x, builder.
create<arith::DivFOp>(
817 logU, builder.
create<arith::SubFOp>(u, cstOne)));
818 Value approximation = builder.
create<arith::SelectOp>(
819 builder.
create<arith::OrIOp>(uSmall, uInf), x, logLarge);
832 struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
836 LogicalResult matchAndRewrite(math::AsinOp op,
841 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
843 Value operand = op.getOperand();
846 if (!(elementType.
isF32() || elementType.
isF16()))
848 "only f32 and f16 type is supported.");
849 std::optional<VectorShape> shape =
vectorShape(operand);
857 return builder.
create<math::FmaOp>(a, b, c);
861 return builder.
create<arith::MulFOp>(a, b);
865 return builder.
create<arith::SubFOp>(a, b);
// Local helper: emits a math::SqrtOp on `a` through the by-reference captured
// `builder` and hands the created op back as a Value.
870 auto sqrt = [&](
Value a) ->
Value {
return builder.
create<math::SqrtOp>(a); };
873 return builder.
create<math::CopySignOp>(a, b);
877 return builder.
create<arith::SelectOp>(a, b, c);
881 Value aa = mul(operand, operand);
882 Value opp = sqrt(sub(bcast(
floatCst(builder, 1.0, elementType)), aa));
885 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, aa,
886 bcast(
floatCst(builder, 0.5, elementType)));
888 Value x = sel(gt, opp, abso);
893 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
894 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
896 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
897 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
898 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
899 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
900 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
901 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
902 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
903 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
904 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
905 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
907 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
911 Value rsub = sub(bcast(
floatCst(builder, 1.57079632679, elementType)), r);
912 r = sel(gt, rsub, r);
913 r = scopy(r, operand);
927 struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
931 LogicalResult matchAndRewrite(math::AcosOp op,
936 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
938 Value operand = op.getOperand();
941 if (!(elementType.
isF32() || elementType.
isF16()))
943 "only f32 and f16 type is supported.");
944 std::optional<VectorShape> shape =
vectorShape(operand);
952 return builder.
create<math::FmaOp>(a, b, c);
956 return builder.
create<arith::MulFOp>(a, b);
959 Value negOperand = builder.
create<arith::NegFOp>(operand);
964 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, operand, zero);
965 Value r = builder.
create<arith::SelectOp>(selR, negOperand, operand);
966 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
968 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, r, chkConst);
971 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
972 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
973 builder.
create<math::AsinOp>(r));
975 Value falseVal = builder.
create<math::SqrtOp>(fma(half, r, half));
976 falseVal = builder.
create<math::AsinOp>(falseVal);
977 falseVal = mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
979 r = builder.
create<arith::SelectOp>(firstPred, trueVal, falseVal);
982 Value greaterThanNegOne =
983 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, operand, negOne);
986 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
988 Value betweenNegOneZero =
989 builder.
create<arith::AndIOp>(greaterThanNegOne, lessThanZero);
991 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
992 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
993 builder.
create<arith::NegFOp>(r));
996 builder.
create<arith::SelectOp>(betweenNegOneZero, trueVal, r);
1016 Value operand = op.getOperand();
1019 if (!(elementType.
isF32() || elementType.
isF16()))
1021 "only f32 and f16 type is supported.");
1022 std::optional<VectorShape> shape =
vectorShape(operand);
1026 return broadcast(builder, value, shape);
1029 const int intervalsCount = 3;
1030 const int polyDegree = 4;
1034 Value pp[intervalsCount][polyDegree + 1];
1035 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1036 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1037 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1038 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1039 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1040 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1041 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1042 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1043 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1044 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1045 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1046 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1047 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1048 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1049 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
1051 Value qq[intervalsCount][polyDegree + 1];
1052 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1053 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1054 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1055 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1056 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1057 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1058 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1059 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1060 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1061 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1062 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1063 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1064 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1065 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1066 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
1068 Value offsets[intervalsCount];
1069 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1070 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1071 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
1073 Value bounds[intervalsCount];
1074 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1075 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1076 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
1078 Value isNegativeArg =
1079 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, operand, zero);
1080 Value negArg = builder.
create<arith::NegFOp>(operand);
1081 Value x = builder.
create<arith::SelectOp>(isNegativeArg, negArg, operand);
1083 Value offset = offsets[0];
1084 Value p[polyDegree + 1];
1085 Value q[polyDegree + 1];
1086 for (
int i = 0; i <= polyDegree; ++i) {
1092 Value isLessThanBound[intervalsCount];
1093 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1094 isLessThanBound[
j] =
1095 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, bounds[
j]);
1096 for (
int i = 0; i <= polyDegree; ++i) {
1097 p[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], p[i],
1099 q[i] = builder.
create<arith::SelectOp>(isLessThanBound[
j], q[i],
1102 offset = builder.
create<arith::SelectOp>(isLessThanBound[
j], offset,
1105 isLessThanBound[intervalsCount - 1] = builder.
create<arith::CmpFOp>(
1106 arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1108 Value pPoly = makePolynomialCalculation(builder, p, x);
1109 Value qPoly = makePolynomialCalculation(builder, q, x);
1110 Value rationalPoly = builder.
create<arith::DivFOp>(pPoly, qPoly);
1111 Value formula = builder.
create<arith::AddFOp>(offset, rationalPoly);
1112 formula = builder.
create<arith::SelectOp>(isLessThanBound[intervalsCount - 1],
1116 Value negFormula = builder.
create<arith::NegFOp>(formula);
1118 builder.
create<arith::SelectOp>(isNegativeArg, negFormula, formula);
1137 Value x = op.getOperand();
1142 std::optional<VectorShape> shape =
vectorShape(x);
1146 return broadcast(builder, value, shape);
1162 Value q = builder.
create<math::FmaOp>(neg4, r, one);
1166 q = builder.
create<math::FmaOp>(r, e, q);
1168 p = bcast(
floatCst(builder, -0x1.a4a000p-12f, et));
1170 p = builder.
create<math::FmaOp>(p, q, c1);
1172 p = builder.
create<math::FmaOp>(p, q, c2);
1174 p = builder.
create<math::FmaOp>(p, q, c3);
1176 p = builder.
create<math::FmaOp>(p, q, c4);
1178 p = builder.
create<math::FmaOp>(p, q, c5);
1180 p = builder.
create<math::FmaOp>(p, q, c6);
1182 p = builder.
create<math::FmaOp>(p, q, c7);
1184 p = builder.
create<math::FmaOp>(p, q, c8);
1186 p = builder.
create<math::FmaOp>(p, q, c9);
1188 Value d = builder.
create<math::FmaOp>(pos2, a, one);
1189 r = builder.
create<arith::DivFOp>(one, d);
1190 q = builder.
create<math::FmaOp>(p, r, r);
1192 Value fmaqah = builder.
create<math::FmaOp>(q, negfa, onehalf);
1193 Value psubq = builder.
create<arith::SubFOp>(p, q);
1194 e = builder.
create<math::FmaOp>(fmaqah, pos2, psubq);
1195 r = builder.
create<math::FmaOp>(e, r, q);
1198 e = builder.
create<math::ExpOp>(builder.
create<arith::NegFOp>(s));
1200 t = builder.
create<math::FmaOp>(builder.
create<arith::NegFOp>(a), a, s);
1201 r = builder.
create<math::FmaOp>(
1203 builder.
create<arith::MulFOp>(builder.
create<arith::MulFOp>(r, e), t));
1205 Value isNotLessThanInf = builder.
create<arith::XOrIOp>(
1206 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
1208 r = builder.
create<arith::SelectOp>(isNotLessThanInf,
1209 builder.
create<arith::AddFOp>(x, x), r);
1210 Value isGreaterThanClamp =
1211 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
1212 r = builder.
create<arith::SelectOp>(isGreaterThanClamp, zero, r);
1215 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1216 r = builder.
create<arith::SelectOp>(
1217 isNegative, builder.
create<arith::SubFOp>(pos2, r), r);
1229 const std::optional<VectorShape> shape,
Value value,
1230 float lowerBound,
float upperBound) {
1231 assert(!std::isnan(lowerBound));
1232 assert(!std::isnan(upperBound));
1235 return broadcast(builder, value, shape);
1238 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1239 return builder.
create<arith::SelectOp>(
1240 builder.
create<arith::CmpFOp>(pred, value, bound), value, bound);
1246 value = selectCmp(arith::CmpFPredicate::UGE, value,
1247 bcast(
f32Cst(builder, lowerBound)));
1248 value = selectCmp(arith::CmpFPredicate::ULE, value,
1249 bcast(
f32Cst(builder, upperBound)));
1257 LogicalResult matchAndRewrite(math::ExpOp op,
1262 ExpApproximation::matchAndRewrite(math::ExpOp op,
1264 auto shape =
vectorShape(op.getOperand().getType());
1266 if (!elementTy.isF32())
1272 return builder.
create<arith::AddFOp>(a, b);
1275 return broadcast(builder, value, shape);
1279 return builder.
create<math::FmaOp>(a, b, c);
1282 return builder.
create<arith::MulFOp>(a, b);
1310 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1312 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1313 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1314 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1315 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1316 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1317 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1318 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1319 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1326 Value x = op.getOperand();
1327 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1328 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1369 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1372 x = fmla(cstExpC1, n, x);
1373 x = fmla(cstExpC2, n, x);
1376 Value z = fmla(x, cstExpP0, cstExpP1);
1377 z = fmla(z, x, cstExpP2);
1378 z = fmla(z, x, cstExpP3);
1379 z = fmla(z, x, cstExpP4);
1380 z = fmla(z, x, cstExpP5);
1381 z = fmla(z, mul(x, x), x);
1386 Value nI32 = builder.
create<arith::FPToSIOp>(i32Vec, n);
1392 Value ret = mul(z, pow2);
1395 return mlir::success();
1410 LogicalResult matchAndRewrite(math::ExpM1Op op,
1416 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1421 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1425 return broadcast(builder, value, shape);
1433 Value x = op.getOperand();
1436 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, u, cstOne);
1437 Value uMinusOne = builder.
create<arith::SubFOp>(u, cstOne);
1438 Value uMinusOneEqNegOne = builder.
create<arith::CmpFOp>(
1439 arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1445 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, logU, u);
1449 uMinusOne, builder.
create<arith::DivFOp>(x, logU));
1450 expm1 = builder.
create<arith::SelectOp>(isInf, u, expm1);
1451 Value approximation = builder.
create<arith::SelectOp>(
1453 builder.
create<arith::SelectOp>(uMinusOneEqNegOne, cstNegOne, expm1));
1464 template <
bool isSine,
typename OpTy>
1469 LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter)
const final;
// 2 / pi, spelled out as a long double literal (note the trailing `L`) so the
// constant keeps full precision until it is narrowed at the point of use.
1473 #define TWO_OVER_PI \
1474 0.6366197723675813430755350534900574481378385829618257949906693762L
1476 1.5707963267948966192313216916397514420985846996875529104874722961L
1481 template <
bool isSine,
typename OpTy>
1482 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1485 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1486 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1491 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1495 return broadcast(builder, value, shape);
1498 return builder.
create<arith::MulFOp>(a, b);
1501 return builder.
create<arith::SubFOp>(a, b);
1506 auto fPToSingedInteger = [&](
Value a) ->
Value {
1507 return builder.
create<arith::FPToSIOp>(i32Vec, a);
1511 return builder.
create<arith::AndIOp>(a, bcast(
i32Cst(builder, 3)));
1515 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::eq, a, b);
1519 return builder.
create<arith::CmpIOp>(arith::CmpIPredicate::sgt, a, b);
1523 return builder.
create<arith::SelectOp>(cond, t, f);
1527 return builder.
create<math::FmaOp>(a, b, c);
1531 return builder.
create<arith::OrIOp>(a, b);
1537 Value x = op.getOperand();
1541 Value y = sub(x, mul(k, piOverTwo));
1544 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1546 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1547 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1548 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1550 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1552 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1555 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1556 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1557 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1559 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1561 Value kMod4 = modulo4(fPToSingedInteger(k));
1563 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1564 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1565 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1566 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1568 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1569 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1570 : bitwiseOr(kR1, kR2);
1572 Value y2 = mul(y, y);
1574 Value base = select(sinuseCos, cstOne, y);
1575 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1576 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1577 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1578 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1579 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1581 Value v1 = fmla(y2, cstC10, cstC8);
1582 Value v2 = fmla(y2, v1, cstC6);
1583 Value v3 = fmla(y2, v2, cstC4);
1584 Value v4 = fmla(y2, v3, cstC2);
1585 Value v5 = fmla(y2, v4, cstOne);
1586 Value v6 = mul(base, v5);
1588 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
// CbrtApproximation: rewrite pattern for math::CbrtOp.
// Visible flow: take |x|, build an initial guess with integer arithmetic on
// the float's bit pattern, refine it twice with a Newton-style update, force
// the result to zero for a zero input, then copy the operand's sign back.
// NOTE(review): this listing is partial (the embedded line numbers skip).
1603 LogicalResult matchAndRewrite(math::CbrtOp op,
1611 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1613 auto operand = op.getOperand();
1618 std::optional<VectorShape> shape =
vectorShape(operand);
// Helper that materializes `attr` as an arith constant; the rest of its body
// is not visible (presumably it broadcasts to `shape` -- TODO confirm).
1627 auto bconst = [&](TypedAttr attr) ->
Value {
1628 Value value = b.create<arith::ConstantOp>(attr);
// Shift amounts, the additive integer constant, and float constants.
1633 Value intTwo = bconst(b.getI32IntegerAttr(2));
1634 Value intFour = bconst(b.getI32IntegerAttr(4));
1635 Value intEight = bconst(b.getI32IntegerAttr(8));
1636 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1637 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1638 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1639 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
// Work on |x| reinterpreted as an integer (intTy is declared on a line that
// is not shown here).
1645 Value absValue = b.create<math::AbsFOp>(operand);
1646 Value intValue = b.create<arith::BitcastOp>(intTy, absValue);
// (v >> 2) + (v >> 4) = v * 5/16.
1647 Value divideBy4 = b.create<arith::ShRSIOp>(intValue, intTwo);
1648 Value divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1649 intValue = b.create<arith::AddIOp>(divideBy4, divideBy16);
// Scale by (1 + 1/16).
1652 divideBy16 = b.create<arith::ShRSIOp>(intValue, intFour);
1653 intValue = b.create<arith::AddIOp>(intValue, divideBy16);
// Scale by (1 + 1/256); the product of the three steps is roughly v / 3.
1656 Value divideBy256 = b.create<arith::ShRSIOp>(intValue, intEight);
1657 intValue = b.create<arith::AddIOp>(intValue, divideBy256);
// Add the constant 0x2a5137a0 (presumably an exponent-bias correction for the
// bit-level estimate -- verify).
1660 intValue = b.create<arith::AddIOp>(intValue, intMagic);
// Reinterpret as float (floatTy is declared on a line that is not shown here)
// and refine: f = (2*f + |x| / f^2) * (1/3).
1664 Value floatValue = b.create<arith::BitcastOp>(floatTy, intValue);
1665 Value squared = b.create<arith::MulFOp>(floatValue, floatValue);
1666 Value mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1667 Value divSquared = b.create<arith::DivFOp>(absValue, squared);
1668 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1669 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
// Second application of the same update.
1672 squared = b.create<arith::MulFOp>(floatValue, floatValue);
1673 mulTwo = b.create<arith::MulFOp>(floatValue, fpTwo);
1674 divSquared = b.create<arith::DivFOp>(absValue, squared);
1675 floatValue = b.create<arith::AddFOp>(mulTwo, divSquared);
1676 floatValue = b.create<arith::MulFOp>(floatValue, fpThird);
// Ordered compare |x| == 0; presumably assigned to `isZero`, which the next
// select reads (the left-hand side is on a line not shown here).
1680 b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absValue, fpZero);
// A zero input yields zero instead of the refined value.
1681 floatValue = b.create<arith::SelectOp>(isZero, fpZero, floatValue);
// Restore the sign of the original operand.
1682 floatValue = b.create<math::CopySignOp>(floatValue, operand);
// RsqrtApproximation: rewrite pattern for math::RsqrtOp built on
// x86vector::RsqrtOp followed by one refinement step.
// NOTE(review): this listing is partial (the embedded line numbers skip), so
// several declarations referenced below are not visible.
1696 LogicalResult matchAndRewrite(math::RsqrtOp op,
1702 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1707 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
// Guard: taken for non-vector operands, empty shapes, or an innermost
// dimension that is not a multiple of 8. The guarded statement is not visible
// (presumably a match failure -- TODO confirm).
1710 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1715 return broadcast(builder, value, shape);
1719 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
// x * cstNegHalf (cstNegHalf is declared on a line not shown here).
1723 Value negHalf = builder.
create<arith::MulFOp>(op.getOperand(), cstNegHalf);
// Operands of an ordered less-than compare of x against cstMinNormPos;
// presumably the `ltMinMask` combined below.
1728 arith::CmpFPredicate::OLT, op.getOperand(), cstMinNormPos);
// True where x equals cstPosInf.
1729 Value infMask = builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1730 op.getOperand(), cstPosInf);
// Union of the two masks above.
1731 Value notNormalFiniteMask = builder.
create<arith::OrIOp>(ltMinMask, infMask);
// Arguments of a call that passes the op's operands, a width of 8, and a
// callback creating x86vector::RsqrtOp. Presumably a call to
// handleMultidimensionalVectors whose result is `yApprox` -- verify.
1735 builder, op->getOperands(), 8, [&builder](
ValueRange operands) ->
Value {
1736 return builder.create<x86vector::RsqrtOp>(operands);
// Refinement: yNewton = yApprox * (yApprox * (negHalf * yApprox) + 1.5),
// i.e. yApprox * (1.5 + cstNegHalf * x * yApprox^2).
1743 Value inner = builder.
create<arith::MulFOp>(negHalf, yApprox);
1744 Value fma = builder.
create<math::FmaOp>(yApprox, inner, cstOnePointFive);
1745 Value yNewton = builder.
create<arith::MulFOp>(yApprox, fma);
// Where notNormalFiniteMask is set, keep the unrefined yApprox; otherwise use
// yNewton. The receiver of this select is on a line not shown here.
1753 builder.
create<arith::SelectOp>(notNormalFiniteMask, yApprox, yNewton);
// populateMathF32ExpansionPattern<OpType>: per-op helper that acts only when
// `predicate` accepts OpType's operation name. The guarded statement is not
// visible in this listing (presumably it adds a pattern for OpType to
// `patterns` with the given benefit -- TODO confirm).
1776 template <
typename OpType>
1781 if (predicate(OpType::getOperationName())) {
// Invokes the helper once per math op below, always forwarding the same
// `patterns`, `predicate` and `benefit`. The ops are listed alphabetically.
1789 populateMathF32ExpansionPattern<math::AcosOp>(
patterns, predicate, benefit);
1790 populateMathF32ExpansionPattern<math::AcoshOp>(
patterns, predicate, benefit);
1791 populateMathF32ExpansionPattern<math::AsinOp>(
patterns, predicate, benefit);
1792 populateMathF32ExpansionPattern<math::AsinhOp>(
patterns, predicate, benefit);
1793 populateMathF32ExpansionPattern<math::AtanOp>(
patterns, predicate, benefit);
1794 populateMathF32ExpansionPattern<math::Atan2Op>(
patterns, predicate, benefit);
1795 populateMathF32ExpansionPattern<math::AtanhOp>(
patterns, predicate, benefit);
1796 populateMathF32ExpansionPattern<math::CbrtOp>(
patterns, predicate, benefit);
1797 populateMathF32ExpansionPattern<math::CosOp>(
patterns, predicate, benefit);
1798 populateMathF32ExpansionPattern<math::CoshOp>(
patterns, predicate, benefit);
1799 populateMathF32ExpansionPattern<math::ErfOp>(
patterns, predicate, benefit);
1800 populateMathF32ExpansionPattern<math::ErfcOp>(
patterns, predicate, benefit);
1801 populateMathF32ExpansionPattern<math::ExpOp>(
patterns, predicate, benefit);
1802 populateMathF32ExpansionPattern<math::Exp2Op>(
patterns, predicate, benefit);
1803 populateMathF32ExpansionPattern<math::ExpM1Op>(
patterns, predicate, benefit);
1804 populateMathF32ExpansionPattern<math::LogOp>(
patterns, predicate, benefit);
1805 populateMathF32ExpansionPattern<math::Log10Op>(
patterns, predicate, benefit);
1806 populateMathF32ExpansionPattern<math::Log1pOp>(
patterns, predicate, benefit);
1807 populateMathF32ExpansionPattern<math::Log2Op>(
patterns, predicate, benefit);
1808 populateMathF32ExpansionPattern<math::PowFOp>(
patterns, predicate, benefit);
1809 populateMathF32ExpansionPattern<math::RsqrtOp>(
patterns, predicate, benefit);
1810 populateMathF32ExpansionPattern<math::SinOp>(
patterns, predicate, benefit);
1811 populateMathF32ExpansionPattern<math::SinhOp>(
patterns, predicate, benefit);
1812 populateMathF32ExpansionPattern<math::SqrtOp>(
patterns, predicate, benefit);
1813 populateMathF32ExpansionPattern<math::TanOp>(
patterns, predicate, benefit);
1814 populateMathF32ExpansionPattern<math::TanhOp>(
patterns, predicate, benefit);
// populateMathPolynomialApproximationPattern<OpType, PatternType>: per-op
// helper that acts only when `predicate` accepts OpType's operation name. The
// guarded statement is not visible in this listing (presumably it adds
// PatternType to `patterns` -- TODO confirm).
1817 template <
typename OpType,
typename PatternType>
1821 if (predicate(OpType::getOperationName())) {
// Pairs each math op with its approximation pattern. Only the start of each
// call is visible here; the argument lists continue on lines not shown.
1830 AcosPolynomialApproximation>(
1833 AsinPolynomialApproximation>(
1835 populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1837 populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1839 populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
// Cosine uses the shared sine/cosine pattern with isSine = false.
1842 CosOp, SinAndCosApproximation<false, math::CosOp>>(
patterns, predicate,
1844 populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1849 populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1851 populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1853 populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1855 populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1857 populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1859 populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
// Sine uses the shared sine/cosine pattern with isSine = true.
1862 SinOp, SinAndCosApproximation<true, math::SinOp>>(
patterns, predicate,
1864 populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
// Name-based predicates built from fixed lists of math op names.
// NOTE(review): the enclosing function headers and the value searched for in
// each list are on lines not shown in this listing.
// First list (13 ops): atan, atan2, tanh, log, log2, log1p, erf, erfc, exp,
// expm1, cbrt, sin, cos.
1872 return llvm::is_contained(
1873 {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1874 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1875 math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1876 math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1877 math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1878 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1879 math::CosOp::getOperationName()},
// Second list (15 ops), passed as a lambda together with `patterns`: the same
// ops as above plus asin and acos.
1884 patterns, [](StringRef name) ->
bool {
1885 return llvm::is_contained(
1886 {math::AtanOp::getOperationName(),
1887 math::Atan2Op::getOperationName(),
1888 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1889 math::Log2Op::getOperationName(),
1890 math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1891 math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1892 math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1893 math::ExpM1Op::getOperationName(),
1894 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1895 math::CosOp::getOperationName()},
// Predicate that accepts only math::RsqrtOp's operation name.
1900 auto predicateRsqrt = [](StringRef name) {
1901 return name == math::RsqrtOp::getOperationName();
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static std::optional< VectorShape > vectorShape(Type type)
static Value boolCst(ImplicitLocOpBuilder &builder, bool value)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Type broadcast(Type type, std::optional< VectorShape > shape)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static void populateMathPolynomialApproximationPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns)
void populateMathF32ExpansionPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(math::ErfcOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.