32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/Support/MathExtras.h"
48 if (
auto vectorType = dyn_cast<VectorType>(type)) {
49 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
64 assert(!isa<VectorType>(type) &&
"must be scalar type");
65 return shape ?
VectorType::get(shape->sizes, type, shape->scalableFlags)
71 std::optional<VectorShape> shape) {
72 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
74 return shape ? BroadcastOp::create(builder, type, value) : value;
100 assert(!operands.empty() &&
"operands must be not empty");
101 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
103 VectorType inputType = cast<VectorType>(operands[0].
getType());
109 return compute(operands);
113 int64_t innerDim = inputShape.back();
114 int64_t expansionDim = innerDim / vectorWidth;
115 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
122 if (expansionDim > 1) {
124 expandedShape.insert(expandedShape.end() - 1, expansionDim);
125 expandedShape.back() = vectorWidth;
127 for (
unsigned i = 0; i < operands.size(); ++i) {
128 auto operand = operands[i];
129 auto eltType = cast<VectorType>(operand.getType()).getElementType();
131 expandedOperands[i] =
132 vector::ShapeCastOp::create(builder, expandedType, operand);
144 for (int64_t i = 0; i < maxIndex; ++i) {
149 extracted[tuple.index()] =
150 vector::ExtractOp::create(builder, tuple.value(), offsets);
152 results[i] = compute(extracted);
156 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
158 Value result = arith::ConstantOp::create(
159 builder, resultExpandedType, builder.
getZeroAttr(resultExpandedType));
161 for (int64_t i = 0; i < maxIndex; ++i)
162 result = vector::InsertOp::create(builder, results[i], result,
166 return vector::ShapeCastOp::create(
175 return arith::ConstantOp::create(builder, builder.
getBoolAttr(value));
180 assert((elementType.
isF16() || elementType.
isF32()) &&
181 "x must be f16 or f32 type.");
182 return arith::ConstantOp::create(builder,
187 return arith::ConstantOp::create(builder, builder.
getF32FloatAttr(value));
195 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
196 return arith::BitcastOp::create(builder, builder.
getF32Type(), i32Value);
205 return arith::SelectOp::create(
207 arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound),
213 return arith::SelectOp::create(
215 arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound),
222 return max(builder,
min(builder, value, upperBound), lowerBound);
228 bool isPositive =
false) {
230 std::optional<VectorShape> shape =
vectorShape(arg);
245 Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf);
246 Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask);
247 Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg);
250 Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half));
252 Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1);
255 Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg);
256 Value biasedExponentBits = arith::ShRUIOp::create(
257 builder, arith::BitcastOp::create(builder, i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits);
262 arith::SubFOp::create(builder, biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
270 std::optional<VectorShape> shape =
vectorShape(arg);
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = arith::AddIOp::create(builder, arg, bias);
284 arith::ShLIOp::create(builder, biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt);
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
304 Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = math::FmaOp::create(builder, x, res, coeffs[i]);
317 template <
typename T>
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand));
360 template <
typename T>
364 LogicalResult matchAndRewrite(T op,
PatternRewriter &rewriter)
const final {
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
368 return insertCasts<T>(op, rewriter);
382 LogicalResult matchAndRewrite(math::AtanOp op,
388 AtanApproximation::matchAndRewrite(math::AtanOp op,
390 auto operand = op.getOperand();
394 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
397 Value abs = math::AbsFOp::create(builder, operand);
404 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds);
405 Value addone = arith::AddFOp::create(builder, abs, one);
406 Value subone = arith::SubFOp::create(builder, abs, one);
407 Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs);
408 Value xden = arith::SelectOp::create(builder, cmp2, addone, one);
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8);
418 xnum = arith::SelectOp::create(builder, cmp1, one, xnum);
419 xden = arith::SelectOp::create(builder, cmp1, abs, xden);
421 Value x = arith::DivFOp::create(builder, xnum, xden);
422 Value xx = arith::MulFOp::create(builder, x, x);
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
439 n = math::FmaOp::create(builder, xx, n, p1);
440 n = math::FmaOp::create(builder, xx, n, p2);
441 n = math::FmaOp::create(builder, xx, n, p3);
442 n = math::FmaOp::create(builder, xx, n, p4);
443 n = arith::MulFOp::create(builder, n, xx);
447 d = math::FmaOp::create(builder, xx, d, q1);
448 d = math::FmaOp::create(builder, xx, d, q2);
449 d = math::FmaOp::create(builder, xx, d, q3);
450 d = math::FmaOp::create(builder, xx, d, q4);
453 Value ans0 = arith::DivFOp::create(builder, n, d);
454 ans0 = math::FmaOp::create(builder, ans0, x, x);
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = arith::AddFOp::create(builder, mpi4, ans0);
459 Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = arith::SubFOp::create(builder, mpi2, ans0);
463 ans = arith::SelectOp::create(builder, cmp1, ans1, ans);
479 LogicalResult matchAndRewrite(math::Atan2Op op,
485 Atan2Approximation::matchAndRewrite(math::Atan2Op op,
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
493 std::optional<VectorShape> shape =
vectorShape(op.getResult());
496 auto div = arith::DivFOp::create(builder, y, x);
497 auto atan = math::AtanOp::create(builder, div);
502 auto addPi = arith::AddFOp::create(builder, atan, pi);
503 auto subPi = arith::SubFOp::create(builder, atan, pi);
505 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi);
509 auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero);
510 Value result = arith::SelectOp::create(builder, xGt, atan, flippedAtan);
514 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero);
516 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero);
517 Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt);
518 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
519 result = arith::SelectOp::create(builder, isHalfPi, halfPi, result);
523 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero);
524 Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt);
525 auto negativeHalfPiPi =
527 result = arith::SelectOp::create(builder, isNegativeHalfPiPi,
528 negativeHalfPiPi, result);
532 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero);
533 Value isNan = arith::AndIOp::create(builder, xZero, yZero);
535 result = arith::SelectOp::create(builder, isNan, cstNan, result);
550 LogicalResult matchAndRewrite(math::TanhOp op,
556 TanhApproximation::matchAndRewrite(math::TanhOp op,
561 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
569 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
570 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
571 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
575 Value tinyMask = arith::CmpFOp::create(
576 builder, arith::CmpFPredicate::OLT,
577 math::AbsFOp::create(builder, op.getOperand()), tiny);
580 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
581 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
582 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
583 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
584 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
585 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
586 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
589 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
590 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
591 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
592 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
595 Value x2 = arith::MulFOp::create(builder, x, x);
598 Value p = math::FmaOp::create(builder, x2, alpha13, alpha11);
599 p = math::FmaOp::create(builder, x2, p, alpha9);
600 p = math::FmaOp::create(builder, x2, p, alpha7);
601 p = math::FmaOp::create(builder, x2, p, alpha5);
602 p = math::FmaOp::create(builder, x2, p, alpha3);
603 p = math::FmaOp::create(builder, x2, p, alpha1);
604 p = arith::MulFOp::create(builder, x, p);
607 Value q = math::FmaOp::create(builder, x2, beta6, beta4);
608 q = math::FmaOp::create(builder, x2, q, beta2);
609 q = math::FmaOp::create(builder, x2, q, beta0);
612 Value res = arith::SelectOp::create(builder, tinyMask, x,
613 arith::DivFOp::create(builder, p, q));
621 0.693147180559945309417232121458176568075500134360255254120680009493393621L
622 #define LOG2E_VALUE \
623 1.442695040888963407359924681001892137426645954152985934135449406931109219L
630 template <
typename Op>
642 template <
typename Op>
649 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
667 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
668 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
669 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
670 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
671 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
672 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
673 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
674 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
675 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
676 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
678 Value x = op.getOperand();
681 x =
max(builder, x, cstMinNormPos);
684 std::pair<Value, Value> pair =
frexp(builder, x,
true);
686 Value e = pair.second;
696 Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x,
698 Value tmp = arith::SelectOp::create(builder, mask, x, cstZero);
700 x = arith::SubFOp::create(builder, x, cstOne);
701 e = arith::SubFOp::create(
702 builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero));
703 x = arith::AddFOp::create(builder, x, tmp);
705 Value x2 = arith::MulFOp::create(builder, x, x);
706 Value x3 = arith::MulFOp::create(builder, x2, x);
710 y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1);
711 y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4);
712 y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7);
713 y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2);
714 y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5);
715 y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8);
716 y0 = math::FmaOp::create(builder, y0, x3, y1);
717 y0 = math::FmaOp::create(builder, y0, x3, y2);
718 y0 = arith::MulFOp::create(builder, y0, x3);
720 y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0);
721 x = arith::AddFOp::create(builder, x, y0);
725 x = math::FmaOp::create(builder, x, cstLog2e, e);
728 x = math::FmaOp::create(builder, e, cstLn2, x);
731 Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT,
732 op.getOperand(), cstZero);
733 Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstZero);
735 Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
736 op.getOperand(), cstPosInf);
742 Value aproximation = arith::SelectOp::create(
743 builder, zeroMask, cstMinusInf,
744 arith::SelectOp::create(
745 builder, invalidMask, cstNan,
746 arith::SelectOp::create(builder, posInfMask, cstPosInf, x)));
754 struct LogApproximation :
public LogApproximationBase<math::LogOp> {
755 using LogApproximationBase::LogApproximationBase;
757 LogicalResult matchAndRewrite(math::LogOp op,
759 return logMatchAndRewrite(op, rewriter,
false);
765 struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
766 using LogApproximationBase::LogApproximationBase;
768 LogicalResult matchAndRewrite(math::Log2Op op,
770 return logMatchAndRewrite(op, rewriter,
true);
784 LogicalResult matchAndRewrite(math::Log1pOp op,
791 Log1pApproximation::matchAndRewrite(math::Log1pOp op,
796 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
810 Value x = op.getOperand();
811 Value u = arith::AddFOp::create(builder, x, cstOne);
813 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne);
814 Value logU = math::LogOp::create(builder, u);
816 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU);
817 Value logLarge = arith::MulFOp::create(
819 arith::DivFOp::create(builder, logU,
820 arith::SubFOp::create(builder, u, cstOne)));
821 Value approximation = arith::SelectOp::create(
822 builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge);
835 struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
839 LogicalResult matchAndRewrite(math::AsinOp op,
844 AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
846 Value operand = op.getOperand();
849 if (!(elementType.
isF32() || elementType.
isF16()))
851 "only f32 and f16 type is supported.");
852 std::optional<VectorShape> shape =
vectorShape(operand);
860 return math::FmaOp::create(builder, a, b, c);
864 return arith::MulFOp::create(builder, a, b);
868 return arith::SubFOp::create(builder, a, b);
871 auto abs = [&](
Value a) ->
Value {
return math::AbsFOp::create(builder, a); };
874 return math::SqrtOp::create(builder, a);
878 return math::CopySignOp::create(builder, a, b);
882 return arith::SelectOp::create(builder, a, b, c);
886 Value aa = mul(operand, operand);
887 Value opp = sqrt(sub(bcast(
floatCst(builder, 1.0, elementType)), aa));
889 Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa,
890 bcast(
floatCst(builder, 0.5, elementType)));
892 Value x = sel(gt, opp, abso);
897 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
898 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
900 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
901 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
902 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
903 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
904 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
905 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
906 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
907 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
908 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
909 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
911 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
915 Value rsub = sub(bcast(
floatCst(builder, 1.57079632679, elementType)), r);
916 r = sel(gt, rsub, r);
917 r = scopy(r, operand);
931 struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
935 LogicalResult matchAndRewrite(math::AcosOp op,
940 AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
942 Value operand = op.getOperand();
945 if (!(elementType.
isF32() || elementType.
isF16()))
947 "only f32 and f16 type is supported.");
948 std::optional<VectorShape> shape =
vectorShape(operand);
956 return math::FmaOp::create(builder, a, b, c);
960 return arith::MulFOp::create(builder, a, b);
963 Value negOperand = arith::NegFOp::create(builder, operand);
968 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero);
969 Value r = arith::SelectOp::create(builder, selR, negOperand, operand);
970 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
972 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst);
975 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
976 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
977 math::AsinOp::create(builder, r));
979 Value falseVal = math::SqrtOp::create(builder, fma(half, r, half));
980 falseVal = math::AsinOp::create(builder, falseVal);
981 falseVal = mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
983 r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal);
986 Value greaterThanNegOne = arith::CmpFOp::create(
987 builder, arith::CmpFPredicate::OGE, operand, negOne);
990 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
992 Value betweenNegOneZero =
993 arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero);
995 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
996 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
997 arith::NegFOp::create(builder, r));
1000 arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r);
1020 Value operand = op.getOperand();
1023 if (!(elementType.
isF32() || elementType.
isF16()))
1025 "only f32 and f16 type is supported.");
1026 std::optional<VectorShape> shape =
vectorShape(operand);
1030 return broadcast(builder, value, shape);
1033 const int intervalsCount = 3;
1034 const int polyDegree = 4;
1038 Value pp[intervalsCount][polyDegree + 1];
1039 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1040 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1041 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1042 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1043 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1044 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1045 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1046 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1047 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1048 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1049 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1050 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1051 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1052 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1053 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
1055 Value qq[intervalsCount][polyDegree + 1];
1056 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1057 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1058 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1059 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1060 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1061 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1062 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1063 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1064 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1065 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1066 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1067 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1068 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1069 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1070 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
1072 Value offsets[intervalsCount];
1073 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1074 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1075 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
1077 Value bounds[intervalsCount];
1078 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1079 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1080 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
1082 Value isNegativeArg =
1083 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
1084 Value negArg = arith::NegFOp::create(builder, operand);
1085 Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand);
1087 Value offset = offsets[0];
1088 Value p[polyDegree + 1];
1089 Value q[polyDegree + 1];
1090 for (
int i = 0; i <= polyDegree; ++i) {
1096 Value isLessThanBound[intervalsCount];
1097 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1098 isLessThanBound[
j] =
1099 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[
j]);
1100 for (
int i = 0; i <= polyDegree; ++i) {
1101 p[i] = arith::SelectOp::create(builder, isLessThanBound[
j], p[i],
1103 q[i] = arith::SelectOp::create(builder, isLessThanBound[
j], q[i],
1106 offset = arith::SelectOp::create(builder, isLessThanBound[
j], offset,
1109 isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(
1110 builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
1112 Value pPoly = makePolynomialCalculation(builder, p, x);
1113 Value qPoly = makePolynomialCalculation(builder, q, x);
1114 Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly);
1115 Value formula = arith::AddFOp::create(builder, offset, rationalPoly);
1116 formula = arith::SelectOp::create(
1117 builder, isLessThanBound[intervalsCount - 1], formula, one);
1120 Value negFormula = arith::NegFOp::create(builder, formula);
1122 arith::SelectOp::create(builder, isNegativeArg, negFormula, formula);
1141 Value x = op.getOperand();
1146 std::optional<VectorShape> shape =
vectorShape(x);
1150 return broadcast(builder, value, shape);
1163 Value a = math::AbsFOp::create(builder, x);
1164 Value p = arith::AddFOp::create(builder, a, pos2);
1165 Value r = arith::DivFOp::create(builder, one, p);
1166 Value q = math::FmaOp::create(builder, neg4, r, one);
1167 Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one),
1170 math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t);
1171 q = math::FmaOp::create(builder, r, e, q);
1173 p = bcast(
floatCst(builder, -0x1.a4a000p-12f, et));
1175 p = math::FmaOp::create(builder, p, q, c1);
1177 p = math::FmaOp::create(builder, p, q, c2);
1179 p = math::FmaOp::create(builder, p, q, c3);
1181 p = math::FmaOp::create(builder, p, q, c4);
1183 p = math::FmaOp::create(builder, p, q, c5);
1185 p = math::FmaOp::create(builder, p, q, c6);
1187 p = math::FmaOp::create(builder, p, q, c7);
1189 p = math::FmaOp::create(builder, p, q, c8);
1191 p = math::FmaOp::create(builder, p, q, c9);
1193 Value d = math::FmaOp::create(builder, pos2, a, one);
1194 r = arith::DivFOp::create(builder, one, d);
1195 q = math::FmaOp::create(builder, p, r, r);
1196 Value negfa = arith::NegFOp::create(builder, a);
1197 Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf);
1198 Value psubq = arith::SubFOp::create(builder, p, q);
1199 e = math::FmaOp::create(builder, fmaqah, pos2, psubq);
1200 r = math::FmaOp::create(builder, e, r, q);
1202 Value s = arith::MulFOp::create(builder, a, a);
1203 e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s));
1205 t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s);
1206 r = math::FmaOp::create(
1208 arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t));
1210 Value isNotLessThanInf = arith::XOrIOp::create(
1212 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf),
1214 r = arith::SelectOp::create(builder, isNotLessThanInf,
1215 arith::AddFOp::create(builder, x, x), r);
1216 Value isGreaterThanClamp =
1217 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal);
1218 r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r);
1221 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero);
1222 r = arith::SelectOp::create(builder, isNegative,
1223 arith::SubFOp::create(builder, pos2, r), r);
1235 const std::optional<VectorShape> shape,
Value value,
1236 float lowerBound,
float upperBound) {
1237 assert(!std::isnan(lowerBound));
1238 assert(!std::isnan(upperBound));
1241 return broadcast(builder, value, shape);
1244 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1245 return arith::SelectOp::create(
1246 builder, arith::CmpFOp::create(builder, pred, value, bound), value,
1253 value = selectCmp(arith::CmpFPredicate::UGE, value,
1254 bcast(
f32Cst(builder, lowerBound)));
1255 value = selectCmp(arith::CmpFPredicate::ULE, value,
1256 bcast(
f32Cst(builder, upperBound)));
1264 LogicalResult matchAndRewrite(math::ExpOp op,
1269 ExpApproximation::matchAndRewrite(math::ExpOp op,
1271 auto shape =
vectorShape(op.getOperand().getType());
1273 if (!elementTy.isF32())
1279 return arith::AddFOp::create(builder, a, b);
1282 return broadcast(builder, value, shape);
1284 auto floor = [&](
Value a) {
return math::FloorOp::create(builder, a); };
1286 return math::FmaOp::create(builder, a, b, c);
1289 return arith::MulFOp::create(builder, a, b);
1317 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
1319 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1320 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1321 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1322 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1323 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1324 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1325 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1326 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
1333 Value x = op.getOperand();
1334 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1335 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
1376 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
1379 x = fmla(cstExpC1, n, x);
1380 x = fmla(cstExpC2, n, x);
1383 Value z = fmla(x, cstExpP0, cstExpP1);
1384 z = fmla(z, x, cstExpP2);
1385 z = fmla(z, x, cstExpP3);
1386 z = fmla(z, x, cstExpP4);
1387 z = fmla(z, x, cstExpP5);
1388 z = fmla(z, mul(x, x), x);
1393 Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n);
1399 Value ret = mul(z, pow2);
1402 return mlir::success();
1417 LogicalResult matchAndRewrite(math::ExpM1Op op,
1423 ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1428 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1432 return broadcast(builder, value, shape);
1440 Value x = op.getOperand();
1441 Value u = math::ExpOp::create(builder, x);
1443 arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne);
1444 Value uMinusOne = arith::SubFOp::create(builder, u, cstOne);
1445 Value uMinusOneEqNegOne = arith::CmpFOp::create(
1446 builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1448 Value logU = math::LogOp::create(builder, u);
1452 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u);
1455 Value expm1 = arith::MulFOp::create(builder, uMinusOne,
1456 arith::DivFOp::create(builder, x, logU));
1457 expm1 = arith::SelectOp::create(builder, isInf, u, expm1);
1458 Value approximation = arith::SelectOp::create(
1459 builder, uEqOneOrNaN, x,
1460 arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1));
1471 template <
bool isSine,
typename OpTy>
1476 LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter)
const final;
1480 #define TWO_OVER_PI \
1481 0.6366197723675813430755350534900574481378385829618257949906693762L
1483 1.5707963267948966192313216916397514420985846996875529104874722961L
1488 template <
bool isSine,
typename OpTy>
1489 LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
1492 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1493 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
1498 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1502 return broadcast(builder, value, shape);
1505 return arith::MulFOp::create(builder, a, b);
1508 return arith::SubFOp::create(builder, a, b);
1510 auto floor = [&](
Value a) {
return math::FloorOp::create(builder, a); };
1513 auto fPToSingedInteger = [&](
Value a) ->
Value {
1514 return arith::FPToSIOp::create(builder, i32Vec, a);
1518 return arith::AndIOp::create(builder, a, bcast(
i32Cst(builder, 3)));
1522 return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a, b);
1526 return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a, b);
1530 return arith::SelectOp::create(builder, cond, t, f);
1534 return math::FmaOp::create(builder, a, b, c);
1538 return arith::OrIOp::create(builder, a, b);
1544 Value x = op.getOperand();
1548 Value y = sub(x, mul(k, piOverTwo));
1551 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1553 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1554 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1555 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1557 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1559 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1562 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1563 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1564 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1566 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
1568 Value kMod4 = modulo4(fPToSingedInteger(k));
1570 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1571 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1572 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1573 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
1575 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1576 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1577 : bitwiseOr(kR1, kR2);
1579 Value y2 = mul(y, y);
1581 Value base = select(sinuseCos, cstOne, y);
1582 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1583 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1584 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1585 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1586 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
1588 Value v1 = fmla(y2, cstC10, cstC8);
1589 Value v2 = fmla(y2, v1, cstC6);
1590 Value v3 = fmla(y2, v2, cstC4);
1591 Value v4 = fmla(y2, v3, cstC2);
1592 Value v5 = fmla(y2, v4, cstOne);
1593 Value v6 = mul(base, v5);
1595 Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
1610 LogicalResult matchAndRewrite(math::CbrtOp op,
1618 CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1620 auto operand = op.getOperand();
1625 std::optional<VectorShape> shape =
vectorShape(operand);
1634 auto bconst = [&](TypedAttr attr) ->
Value {
1635 Value value = arith::ConstantOp::create(b, attr);
1640 Value intTwo = bconst(b.getI32IntegerAttr(2));
1641 Value intFour = bconst(b.getI32IntegerAttr(4));
1642 Value intEight = bconst(b.getI32IntegerAttr(8));
1643 Value intMagic = bconst(b.getI32IntegerAttr(0x2a5137a0));
1644 Value fpThird = bconst(b.getF32FloatAttr(0.33333333f));
1645 Value fpTwo = bconst(b.getF32FloatAttr(2.0f));
1646 Value fpZero = bconst(b.getF32FloatAttr(0.0f));
1652 Value absValue = math::AbsFOp::create(b, operand);
1653 Value intValue = arith::BitcastOp::create(b, intTy, absValue);
1654 Value divideBy4 = arith::ShRSIOp::create(b, intValue, intTwo);
1655 Value divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1656 intValue = arith::AddIOp::create(b, divideBy4, divideBy16);
1659 divideBy16 = arith::ShRSIOp::create(b, intValue, intFour);
1660 intValue = arith::AddIOp::create(b, intValue, divideBy16);
1663 Value divideBy256 = arith::ShRSIOp::create(b, intValue, intEight);
1664 intValue = arith::AddIOp::create(b, intValue, divideBy256);
1667 intValue = arith::AddIOp::create(b, intValue, intMagic);
1671 Value floatValue = arith::BitcastOp::create(b, floatTy, intValue);
1672 Value squared = arith::MulFOp::create(b, floatValue, floatValue);
1673 Value mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1674 Value divSquared = arith::DivFOp::create(b, absValue, squared);
1675 floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1676 floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1679 squared = arith::MulFOp::create(b, floatValue, floatValue);
1680 mulTwo = arith::MulFOp::create(b, floatValue, fpTwo);
1681 divSquared = arith::DivFOp::create(b, absValue, squared);
1682 floatValue = arith::AddFOp::create(b, mulTwo, divSquared);
1683 floatValue = arith::MulFOp::create(b, floatValue, fpThird);
1687 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absValue, fpZero);
1688 floatValue = arith::SelectOp::create(b, isZero, fpZero, floatValue);
1689 floatValue = math::CopySignOp::create(b, floatValue, operand);
1703 LogicalResult matchAndRewrite(math::RsqrtOp op,
1709 RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1714 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1717 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1722 return broadcast(builder, value, shape);
1726 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1730 Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf);
1734 Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT,
1735 op.getOperand(), cstMinNormPos);
1736 Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
1737 op.getOperand(), cstPosInf);
1738 Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask);
1742 builder, op->getOperands(), 8, [&builder](
ValueRange operands) ->
Value {
1743 return x86vector::RsqrtOp::create(builder, operands);
1750 Value inner = arith::MulFOp::create(builder, negHalf, yApprox);
1751 Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive);
1752 Value yNewton = arith::MulFOp::create(builder, yApprox, fma);
1760 arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton);
1783 template <
typename OpType>
1788 if (predicate(OpType::getOperationName())) {
1796 populateMathF32ExpansionPattern<math::AcosOp>(
patterns, predicate, benefit);
1797 populateMathF32ExpansionPattern<math::AcoshOp>(
patterns, predicate, benefit);
1798 populateMathF32ExpansionPattern<math::AsinOp>(
patterns, predicate, benefit);
1799 populateMathF32ExpansionPattern<math::AsinhOp>(
patterns, predicate, benefit);
1800 populateMathF32ExpansionPattern<math::AtanOp>(
patterns, predicate, benefit);
1801 populateMathF32ExpansionPattern<math::Atan2Op>(
patterns, predicate, benefit);
1802 populateMathF32ExpansionPattern<math::AtanhOp>(
patterns, predicate, benefit);
1803 populateMathF32ExpansionPattern<math::CbrtOp>(
patterns, predicate, benefit);
1804 populateMathF32ExpansionPattern<math::CosOp>(
patterns, predicate, benefit);
1805 populateMathF32ExpansionPattern<math::CoshOp>(
patterns, predicate, benefit);
1806 populateMathF32ExpansionPattern<math::ErfOp>(
patterns, predicate, benefit);
1807 populateMathF32ExpansionPattern<math::ErfcOp>(
patterns, predicate, benefit);
1808 populateMathF32ExpansionPattern<math::ExpOp>(
patterns, predicate, benefit);
1809 populateMathF32ExpansionPattern<math::Exp2Op>(
patterns, predicate, benefit);
1810 populateMathF32ExpansionPattern<math::ExpM1Op>(
patterns, predicate, benefit);
1811 populateMathF32ExpansionPattern<math::LogOp>(
patterns, predicate, benefit);
1812 populateMathF32ExpansionPattern<math::Log10Op>(
patterns, predicate, benefit);
1813 populateMathF32ExpansionPattern<math::Log1pOp>(
patterns, predicate, benefit);
1814 populateMathF32ExpansionPattern<math::Log2Op>(
patterns, predicate, benefit);
1815 populateMathF32ExpansionPattern<math::PowFOp>(
patterns, predicate, benefit);
1816 populateMathF32ExpansionPattern<math::RsqrtOp>(
patterns, predicate, benefit);
1817 populateMathF32ExpansionPattern<math::SinOp>(
patterns, predicate, benefit);
1818 populateMathF32ExpansionPattern<math::SinhOp>(
patterns, predicate, benefit);
1819 populateMathF32ExpansionPattern<math::SqrtOp>(
patterns, predicate, benefit);
1820 populateMathF32ExpansionPattern<math::TanOp>(
patterns, predicate, benefit);
1821 populateMathF32ExpansionPattern<math::TanhOp>(
patterns, predicate, benefit);
1824 template <
typename OpType,
typename PatternType>
1828 if (predicate(OpType::getOperationName())) {
1837 AcosPolynomialApproximation>(
1840 AsinPolynomialApproximation>(
1842 populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
1844 populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
1846 populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
1849 CosOp, SinAndCosApproximation<false, math::CosOp>>(
patterns, predicate,
1851 populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
1856 populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
1858 populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
1860 populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
1862 populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
1864 populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
1866 populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
1869 SinOp, SinAndCosApproximation<true, math::SinOp>>(
patterns, predicate,
1871 populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
1879 return llvm::is_contained(
1880 {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1881 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1882 math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1883 math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1884 math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1885 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1886 math::CosOp::getOperationName()},
1891 patterns, [](StringRef name) ->
bool {
1892 return llvm::is_contained(
1893 {math::AtanOp::getOperationName(),
1894 math::Atan2Op::getOperationName(),
1895 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1896 math::Log2Op::getOperationName(),
1897 math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1898 math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1899 math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1900 math::ExpM1Op::getOperationName(),
1901 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1902 math::CosOp::getOperationName()},
1907 auto predicateRsqrt = [](StringRef name) {
1908 return name == math::RsqrtOp::getOperationName();
static llvm::ManagedStatic< PassManagerOptions > options
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static std::optional< VectorShape > vectorShape(Type type)
static Value boolCst(ImplicitLocOpBuilder &builder, bool value)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Type broadcast(Type type, std::optional< VectorShape > shape)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static void populateMathPolynomialApproximationPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns)
void populateMathF32ExpansionPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(math::ErfcOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.