32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/Support/MathExtras.h"
// Fragment: when `type` is a VectorType, package its shape and scalable-dims
// flags into a VectorShape.
// NOTE(review): the signature and the non-vector path are not visible here.
48 if (
auto vectorType = dyn_cast<VectorType>(type)) {
49 return VectorShape{vectorType.getShape(), vectorType.getScalableDims()};
// Fragment: `type` must be a scalar (non-vector) type. When `shape` is set, a
// VectorType is built from shape->sizes and shape->scalableFlags with `type`
// as element type.
// NOTE(review): the alternative of the conditional is outside the visible lines.
64 assert(!isa<VectorType>(type) &&
"must be scalar type");
65 return shape ? VectorType::get(
shape->sizes, type,
shape->scalableFlags)
// Fragment: `value` must be a scalar. A BroadcastOp to `type` is created only
// when `shape` is set; otherwise `value` is returned unchanged.
// NOTE(review): how `type` is computed is not visible in this view.
71 std::optional<VectorShape>
shape) {
72 assert(!isa<VectorType>(value.
getType()) &&
"must be scalar value");
74 return shape ? BroadcastOp::create(builder, type, value) : value;
// Fragment of a helper that applies `compute` to vector operands in slices.
// NOTE(review): several interior lines (declarations of inputShape,
// expandedShape, maxIndex, offsets, ...) are not visible in this view.
// Preconditions: at least one operand and a positive vector width.
100 assert(!operands.empty() &&
"operands must be not empty");
101 assert(vectorWidth > 0 &&
"vector width must be larger than 0");
// The first operand's type is taken as the reference vector type.
103 VectorType inputType = cast<VectorType>(operands[0].
getType());
// Direct call on the unmodified operands; the guarding condition is not
// visible here.
109 return compute(operands);
// The innermost dimension must be an exact multiple of vectorWidth.
113 int64_t innerDim = inputShape.back();
114 int64_t expansionDim = innerDim / vectorWidth;
115 assert((innerDim % vectorWidth == 0) &&
"invalid inner dimension size");
// When the inner dimension holds more than one vectorWidth chunk, insert an
// extra dimension before the last one and make the last one vectorWidth.
122 if (expansionDim > 1) {
124 expandedShape.insert(expandedShape.end() - 1, expansionDim);
125 expandedShape.back() = vectorWidth;
// Shape-cast every operand to the expanded shape, keeping its element type.
127 for (
unsigned i = 0; i < operands.size(); ++i) {
128 auto operand = operands[i];
129 auto eltType = cast<VectorType>(operand.getType()).getElementType();
130 auto expandedType = VectorType::get(expandedShape, eltType);
131 expandedOperands[i] =
132 vector::ShapeCastOp::create(builder, expandedType, operand);
// For each index below maxIndex: extract a slice of every expanded operand at
// `offsets` and run `compute` on the extracted slices.
144 for (
int64_t i = 0; i < maxIndex; ++i) {
148 for (
const auto &tuple : llvm::enumerate(expandedOperands))
149 extracted[tuple.index()] =
150 vector::ExtractOp::create(builder, tuple.value(), offsets);
152 results[i] = compute(extracted);
// The result element type is taken from the first computed slice; the
// expanded result type is used with a zero attribute (the op being created is
// on a line not visible here).
156 Type resultEltType = cast<VectorType>(results[0].
getType()).getElementType();
157 Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
159 builder, resultExpandedType, builder.
getZeroAttr(resultExpandedType));
// Insert each computed slice back into the accumulated result.
161 for (
int64_t i = 0; i < maxIndex; ++i)
162 result = vector::InsertOp::create(builder, results[i],
result,
// Shape-cast the assembled result back to the original input shape.
166 return vector::ShapeCastOp::create(
167 builder, VectorType::get(inputShape, resultEltType),
result);
// Fragment: creates an arith constant from a boolean attribute holding `value`.
175 return arith::ConstantOp::create(builder, builder.
getBoolAttr(value));
// Fragment: the element type is restricted to f16 or f32 before a constant is
// created.
// NOTE(review): the arguments of the ConstantOp are not visible here.
180 assert((elementType.
isF16() || elementType.
isF32()) &&
181 "x must be f16 or f32 type.");
182 return arith::ConstantOp::create(builder,
// Fragment: creates an arith constant from an f32 float attribute of `value`.
187 return arith::ConstantOp::create(builder, builder.
getF32FloatAttr(value));
// Fragment: materializes `bits` as an i32 constant (after a static_cast to
// int32_t) and bitcasts it to the f32 type.
195 Value i32Value =
i32Cst(builder,
static_cast<int32_t
>(bits));
196 return arith::BitcastOp::create(builder, builder.
getF32Type(), i32Value);
// Fragment: a select whose condition is the ULT (unordered less-than)
// comparison of `value` against `bound`.
// NOTE(review): the select's value operands are not visible; presumably this
// is the `min` helper used by clamp — confirm.
205 return arith::SelectOp::create(
207 arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT, value, bound),
// Fragment: a select whose condition is the UGT (unordered greater-than)
// comparison of `value` against `bound`.
// NOTE(review): the select's value operands are not visible; presumably this
// is the `max` helper used by clamp — confirm.
213 return arith::SelectOp::create(
215 arith::CmpFOp::create(builder, arith::CmpFPredicate::UGT, value, bound),
// Fragment: applies min against upperBound first, then max against lowerBound.
222 return max(builder,
min(builder, value, upperBound), lowerBound);
// Fragment of a frexp-style decomposition returning {normalizedFraction,
// exponent}. `isPositive` defaults to false.
// NOTE(review): definitions of cstHalf, cstInvMantMask, cst126f, i32, i32Vec,
// f32Vec and bcast are on lines not visible in this view.
228 bool isPositive =
false) {
// Reinterpret the float constants and the argument as 32-bit integers.
245 Value i32Half = arith::BitcastOp::create(builder, i32, cstHalf);
246 Value i32InvMantMask = arith::BitcastOp::create(builder, i32, cstInvMantMask);
247 Value i32Arg = arith::BitcastOp::create(builder, i32Vec, arg);
// AND the argument bits with the mask, OR in the bit pattern of cstHalf, and
// reinterpret the result as float to obtain the normalized fraction.
250 Value tmp0 = arith::AndIOp::create(builder, i32Arg, bcast(i32InvMantMask));
251 Value tmp1 = arith::OrIOp::create(builder, tmp0, bcast(i32Half));
252 Value normalizedFraction = arith::BitcastOp::create(builder, f32Vec, tmp1);
// Take |arg| unless the caller states the argument is already positive, then
// shift its bits right by 23 and convert the integer result to float.
255 Value arg0 = isPositive ? arg : math::AbsFOp::create(builder, arg);
256 Value biasedExponentBits = arith::ShRUIOp::create(
257 builder, arith::BitcastOp::create(builder, i32Vec, arg0),
258 bcast(
i32Cst(builder, 23)));
259 Value biasedExponent =
260 arith::SIToFPOp::create(builder, f32Vec, biasedExponentBits);
// Subtract the cst126f constant from the biased exponent.
262 arith::SubFOp::create(builder, biasedExponent, bcast(cst126f));
264 return {normalizedFraction, exponent};
// Fragment: adds 127 to the integer argument, shifts the sum left by 23 bits,
// and bitcasts the shifted integer to f32Vec.
// NOTE(review): the declaration of exp2ValueInt is on a line not visible here.
278 auto exponetBitLocation = bcast(
i32Cst(builder, 23));
280 auto bias = bcast(
i32Cst(builder, 127));
282 Value biasedArg = arith::AddIOp::create(builder, arg, bias);
284 arith::ShLIOp::create(builder, biasedArg, exponetBitLocation);
285 Value exp2ValueF32 = arith::BitcastOp::create(builder, f32Vec, exp2ValueInt);
// Fragment: evaluates a polynomial in `x` from `coeffs` with chained FMAs,
// restricted to f32 or f16 element types.
// NOTE(review): the handling of the empty and single-coefficient cases is on
// lines not visible here.
294 assert((elementType.
isF32() || elementType.
isF16()) &&
295 "x must be f32 or f16 type");
301 if (coeffs.size() == 1)
// Seed with the two highest-index coefficients: x * c[n-1] + c[n-2].
304 Value res = math::FmaOp::create(builder, x, coeffs[coeffs.size() - 1],
305 coeffs[coeffs.size() - 2]);
// Fold in the remaining coefficients from index n-3 down to 0. The index is a
// signed ptrdiff_t so the `i >= 0` test can terminate.
306 for (
auto i = ptrdiff_t(coeffs.size()) - 3; i >= 0; --i) {
307 res = math::FmaOp::create(builder, x, res, coeffs[i]);
// Fragment: derives an f32 counterpart of `origType` (shaped types are cloned
// with an f32 element type) and extends operands to `newType` with ExtFOp.
// NOTE(review): the FloatType branch body and the operand loop header are not
// visible in this view.
335 if (
auto shaped = dyn_cast<ShapedType>(origType)) {
336 newType = shaped.clone(rewriter.
getF32Type());
337 }
else if (isa<FloatType>(origType)) {
341 "unable to find F32 equivalent type");
347 operands.push_back(arith::ExtFOp::create(rewriter, loc, newType, operand));
// Fragment of a rewrite pattern templated on op type T: it inherits the
// OpRewritePattern constructors, and T must carry the
// SameOperandsAndResultType trait.
// NOTE(review): the struct header and the rewrite body are not visible here.
363 using OpRewritePattern<T>::OpRewritePattern;
364 LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter)
const final {
366 T::template hasTrait<mlir::OpTrait::SameOperandsAndResultType>(),
367 "requires same operands and result types");
// Fragment of the math::AtanOp rewrite (declaration followed by definition).
// NOTE(review): definitions of `one`, `twoThirds`, `cmp1`, `cmp2`, and the
// initial values of `n` and `d` are on lines not visible in this view.
382 LogicalResult matchAndRewrite(math::AtanOp op,
383 PatternRewriter &rewriter)
const final;
388AtanApproximation::matchAndRewrite(math::AtanOp op,
390 auto operand = op.getOperand();
394 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
396 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// The range reduction below works on |operand|.
397 Value
abs = math::AbsFOp::create(builder, operand);
// Where cmp2 holds, the reduced argument becomes (|x| - 1) / (|x| + 1);
// otherwise it stays |x| / 1.
404 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, twoThirds);
405 Value addone = arith::AddFOp::create(builder, abs, one);
406 Value subone = arith::SubFOp::create(builder, abs, one);
407 Value xnum = arith::SelectOp::create(builder, cmp2, subone, abs);
408 Value xden = arith::SelectOp::create(builder, cmp2, addone, one);
410 auto bcast = [&](Value value) -> Value {
// Where |x| exceeds the tan3pio8 constant, the reduced argument is 1 / |x|.
415 auto tan3pio8 = bcast(
f32Cst(builder, 2.41421356237309504880));
417 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, abs, tan3pio8);
418 xnum = arith::SelectOp::create(builder, cmp1, one, xnum);
419 xden = arith::SelectOp::create(builder, cmp1, abs, xden);
// Reduced argument x and its square xx.
421 Value x = arith::DivFOp::create(builder, xnum, xden);
422 Value xx = arith::MulFOp::create(builder, x, x);
// Coefficients p0..p4 (used in n) and q0..q4 (used in d).
426 auto p0 = bcast(
f32Cst(builder, -8.750608600031904122785e-01));
427 auto p1 = bcast(
f32Cst(builder, -1.615753718733365076637e+01));
428 auto p2 = bcast(
f32Cst(builder, -7.500855792314704667340e+01));
429 auto p3 = bcast(
f32Cst(builder, -1.228866684490136173410e+02));
430 auto p4 = bcast(
f32Cst(builder, -6.485021904942025371773e+01));
431 auto q0 = bcast(
f32Cst(builder, +2.485846490142306297962e+01));
432 auto q1 = bcast(
f32Cst(builder, +1.650270098316988542046e+02));
433 auto q2 = bcast(
f32Cst(builder, +4.328810604912902668951e+02));
434 auto q3 = bcast(
f32Cst(builder, +4.853903996359136964868e+02));
435 auto q4 = bcast(
f32Cst(builder, +1.945506571482613964425e+02));
// Numerator: chained FMAs in xx over p1..p4, then one more multiply by xx.
439 n = math::FmaOp::create(builder, xx, n, p1);
440 n = math::FmaOp::create(builder, xx, n, p2);
441 n = math::FmaOp::create(builder, xx, n, p3);
442 n = math::FmaOp::create(builder, xx, n, p4);
443 n = arith::MulFOp::create(builder, n, xx);
// Denominator: chained FMAs in xx over q1..q4.
447 d = math::FmaOp::create(builder, xx, d, q1);
448 d = math::FmaOp::create(builder, xx, d, q2);
449 d = math::FmaOp::create(builder, xx, d, q3);
450 d = math::FmaOp::create(builder, xx, d, q4);
// ans0 = (n / d) * x + x.
453 Value ans0 = arith::DivFOp::create(builder, n, d);
454 ans0 = math::FmaOp::create(builder, ans0, x, x);
// Undo the range reduction: pi/4 + ans0 where cmp2 holds, pi/2 - ans0 where
// cmp1 holds (cmp1 is applied last and so takes precedence).
457 Value mpi4 = bcast(
f32Cst(builder, llvm::numbers::pi / 4));
458 Value ans2 = arith::AddFOp::create(builder, mpi4, ans0);
459 Value ans = arith::SelectOp::create(builder, cmp2, ans2, ans0);
461 Value mpi2 = bcast(
f32Cst(builder, llvm::numbers::pi / 2));
462 Value ans1 = arith::SubFOp::create(builder, mpi2, ans0);
463 ans = arith::SelectOp::create(builder, cmp1, ans1, ans);
// Fragment of the math::Atan2Op rewrite pattern: builds atan(y / x) and then
// patches the result per sign of x and for the x == 0 cases.
// NOTE(review): definitions of `pi`, `zero`, `atanGt`, `xZero`, `yGt`, `yLt`,
// `yZero`, `cstNan` and the value of negativeHalfPiPi are not visible here.
475struct Atan2Approximation :
public OpRewritePattern<math::Atan2Op> {
479 LogicalResult matchAndRewrite(math::Atan2Op op,
480 PatternRewriter &rewriter)
const final;
485Atan2Approximation::matchAndRewrite(math::Atan2Op op,
486 PatternRewriter &rewriter)
const {
// Operand 0 is y and operand 1 is x.
487 auto y = op.getOperand(0);
488 auto x = op.getOperand(1);
492 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
493 std::optional<VectorShape> shape =
vectorShape(op.getResult());
496 auto div = arith::DivFOp::create(builder, y, x);
497 auto atan = math::AtanOp::create(builder,
div);
// flippedAtan shifts atan by pi toward the opposite sign: atan - pi where
// atanGt holds, atan + pi otherwise.
502 auto addPi = arith::AddFOp::create(builder, atan, pi);
503 auto subPi = arith::SubFOp::create(builder, atan, pi);
505 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, atan, zero);
506 auto flippedAtan = arith::SelectOp::create(builder, atanGt, subPi, addPi);
// For x > 0 keep atan; otherwise use the flipped value.
509 auto xGt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, x, zero);
510 Value
result = arith::SelectOp::create(builder, xGt, atan, flippedAtan);
// x == 0 and y > 0: the result is halfPi.
514 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, x, zero);
516 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, y, zero);
517 Value isHalfPi = arith::AndIOp::create(builder, xZero, yGt);
518 auto halfPi =
broadcast(builder,
f32Cst(builder, 1.57079632679f), shape);
519 result = arith::SelectOp::create(builder, isHalfPi, halfPi,
result);
// x == 0 and y < 0: the result is negativeHalfPiPi.
523 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, y, zero);
524 Value isNegativeHalfPiPi = arith::AndIOp::create(builder, xZero, yLt);
525 auto negativeHalfPiPi =
527 result = arith::SelectOp::create(builder, isNegativeHalfPiPi,
528 negativeHalfPiPi,
result);
// x == 0 and y == 0: the result is cstNan.
532 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, y, zero);
533 Value isNan = arith::AndIOp::create(builder, xZero, yZero);
535 result = arith::SelectOp::create(builder, isNan, cstNan,
result);
// Fragment of the math::TanhOp rewrite pattern: a clamped rational
// approximation p(x) / q(x), with a pass-through of x for tiny inputs.
// NOTE(review): type checks and the final replacement are not visible here.
546struct TanhApproximation :
public OpRewritePattern<math::TanhOp> {
550 LogicalResult matchAndRewrite(math::TanhOp op,
551 PatternRewriter &rewriter)
const final;
556TanhApproximation::matchAndRewrite(math::TanhOp op,
557 PatternRewriter &rewriter)
const {
561 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
563 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
564 auto bcast = [&](Value value) -> Value {
// Clamp the operand to [-7.99881172180175781, 7.99881172180175781].
569 Value minusClamp = bcast(
f32Cst(builder, -7.99881172180175781f));
570 Value plusClamp = bcast(
f32Cst(builder, 7.99881172180175781f));
571 Value x =
clamp(builder, op.getOperand(), minusClamp, plusClamp);
// tinyMask marks inputs with |operand| < 0.0004; it uses the unclamped operand.
574 Value tiny = bcast(
f32Cst(builder, 0.0004f));
575 Value tinyMask = arith::CmpFOp::create(
576 builder, arith::CmpFPredicate::OLT,
577 math::AbsFOp::create(builder, op.getOperand()), tiny);
// Coefficients used by p (alpha1..alpha13).
580 Value alpha1 = bcast(
f32Cst(builder, 4.89352455891786e-03f));
581 Value alpha3 = bcast(
f32Cst(builder, 6.37261928875436e-04f));
582 Value alpha5 = bcast(
f32Cst(builder, 1.48572235717979e-05f));
583 Value alpha7 = bcast(
f32Cst(builder, 5.12229709037114e-08f));
584 Value alpha9 = bcast(
f32Cst(builder, -8.60467152213735e-11f));
585 Value alpha11 = bcast(
f32Cst(builder, 2.00018790482477e-13f));
586 Value alpha13 = bcast(
f32Cst(builder, -2.76076847742355e-16f));
// Coefficients used by q (beta0..beta6).
589 Value beta0 = bcast(
f32Cst(builder, 4.89352518554385e-03f));
590 Value beta2 = bcast(
f32Cst(builder, 2.26843463243900e-03f));
591 Value beta4 = bcast(
f32Cst(builder, 1.18534705686654e-04f));
592 Value beta6 = bcast(
f32Cst(builder, 1.19825839466702e-06f));
// Both polynomials are evaluated in x^2.
595 Value x2 = arith::MulFOp::create(builder, x, x);
// p: chained FMAs in x2 from alpha13 down to alpha1, then multiplied by x.
598 Value p = math::FmaOp::create(builder, x2, alpha13, alpha11);
599 p = math::FmaOp::create(builder, x2, p, alpha9);
600 p = math::FmaOp::create(builder, x2, p, alpha7);
601 p = math::FmaOp::create(builder, x2, p, alpha5);
602 p = math::FmaOp::create(builder, x2, p, alpha3);
603 p = math::FmaOp::create(builder, x2, p, alpha1);
604 p = arith::MulFOp::create(builder, x, p);
// q: chained FMAs in x2 from beta6 down to beta0.
607 Value q = math::FmaOp::create(builder, x2, beta6, beta4);
608 q = math::FmaOp::create(builder, x2, q, beta2);
609 q = math::FmaOp::create(builder, x2, q, beta0);
// Tiny inputs return x itself; everything else returns p / q.
612 Value res = arith::SelectOp::create(builder, tinyMask, x,
613 arith::DivFOp::create(builder, p, q));
621 0.693147180559945309417232121458176568075500134360255254120680009493393621L
623 1.442695040888963407359924681001892137426645954152985934135449406931109219L
// Fragment of the shared log rewrite, templated on the op type.
// NOTE(review): the struct body, the trailing parameter of logMatchAndRewrite,
// and the definitions of cstLn2, cstLog2e, y0, y1, y2 are not visible here.
630template <
typename Op>
642template <
typename Op>
644LogApproximationBase<Op>::logMatchAndRewrite(Op op, PatternRewriter &rewriter,
649 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
651 ImplicitLocOpBuilder builder(op->
getLoc(), rewriter);
652 auto bcast = [&](Value value) -> Value {
656 Value cstZero = bcast(
f32Cst(builder, 0.0f));
657 Value cstOne = bcast(
f32Cst(builder, 1.0f));
658 Value cstNegHalf = bcast(
f32Cst(builder, -0.5f));
// Special values built directly from their 32-bit patterns.
661 Value cstMinNormPos = bcast(
f32FromBits(builder, 0x00800000u));
662 Value cstMinusInf = bcast(
f32FromBits(builder, 0xff800000u));
663 Value cstPosInf = bcast(
f32FromBits(builder, 0x7f800000u));
664 Value cstNan = bcast(
f32FromBits(builder, 0x7fc00000));
// Polynomial coefficients P0..P8 and the comparison threshold.
667 Value cstCephesSQRTHF = bcast(
f32Cst(builder, 0.707106781186547524f));
668 Value cstCephesLogP0 = bcast(
f32Cst(builder, 7.0376836292E-2f));
669 Value cstCephesLogP1 = bcast(
f32Cst(builder, -1.1514610310E-1f));
670 Value cstCephesLogP2 = bcast(
f32Cst(builder, 1.1676998740E-1f));
671 Value cstCephesLogP3 = bcast(
f32Cst(builder, -1.2420140846E-1f));
672 Value cstCephesLogP4 = bcast(
f32Cst(builder, +1.4249322787E-1f));
673 Value cstCephesLogP5 = bcast(
f32Cst(builder, -1.6668057665E-1f));
674 Value cstCephesLogP6 = bcast(
f32Cst(builder, +2.0000714765E-1f));
675 Value cstCephesLogP7 = bcast(
f32Cst(builder, -2.4999993993E-1f));
676 Value cstCephesLogP8 = bcast(
f32Cst(builder, +3.3333331174E-1f));
678 Value x = op.getOperand();
// Raise x to at least cstMinNormPos before decomposing it.
681 x =
max(builder, x, cstMinNormPos);
// frexp is called with isPositive = true; `e` is the exponent part.
684 std::pair<Value, Value> pair =
frexp(builder, x,
true);
686 Value e = pair.second;
// Where `mask` holds: e is reduced by one and x becomes (x - 1) + x;
// elsewhere x becomes x - 1.
696 Value mask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x,
698 Value tmp = arith::SelectOp::create(builder, mask, x, cstZero);
700 x = arith::SubFOp::create(builder, x, cstOne);
701 e = arith::SubFOp::create(
702 builder, e, arith::SelectOp::create(builder, mask, cstOne, cstZero));
703 x = arith::AddFOp::create(builder, x, tmp);
705 Value x2 = arith::MulFOp::create(builder, x, x);
706 Value x3 = arith::MulFOp::create(builder, x2, x);
// Evaluate the polynomial as three groups (y0, y1, y2) of three coefficients
// each, then combine the groups with FMAs in x3.
710 y0 = math::FmaOp::create(builder, cstCephesLogP0, x, cstCephesLogP1);
711 y1 = math::FmaOp::create(builder, cstCephesLogP3, x, cstCephesLogP4);
712 y2 = math::FmaOp::create(builder, cstCephesLogP6, x, cstCephesLogP7);
713 y0 = math::FmaOp::create(builder, y0, x, cstCephesLogP2);
714 y1 = math::FmaOp::create(builder, y1, x, cstCephesLogP5);
715 y2 = math::FmaOp::create(builder, y2, x, cstCephesLogP8);
716 y0 = math::FmaOp::create(builder, y0, x3, y1);
717 y0 = math::FmaOp::create(builder, y0, x3, y2);
718 y0 = arith::MulFOp::create(builder, y0, x3);
// Add the -0.5 * x^2 term, then add the polynomial to x.
720 y0 = math::FmaOp::create(builder, cstNegHalf, x2, y0);
721 x = arith::AddFOp::create(builder, x, y0);
// Two alternative final combinations: x * cstLog2e + e, or e * cstLn2 + x.
// The condition choosing between them is not visible here.
725 x = math::FmaOp::create(builder, x, cstLog2e, e);
728 x = math::FmaOp::create(builder, e, cstLn2, x);
// Masks on the ORIGINAL operand: ULT 0 (unordered, so NaN also matches),
// equal to 0, and equal to +inf.
731 Value invalidMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::ULT,
732 op.getOperand(), cstZero);
733 Value zeroMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
734 op.getOperand(), cstZero);
735 Value posInfMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
736 op.getOperand(), cstPosInf);
// Priority of the special cases: zero -> -inf, invalid -> NaN, +inf -> +inf,
// otherwise the computed x.
742 Value aproximation = arith::SelectOp::create(
743 builder, zeroMask, cstMinusInf,
744 arith::SelectOp::create(
745 builder, invalidMask, cstNan,
746 arith::SelectOp::create(builder, posInfMask, cstPosInf, x)));
// math::LogOp pattern: forwards to logMatchAndRewrite with the trailing
// boolean argument set to false.
// NOTE(review): the name of that parameter is not visible in this view.
754struct LogApproximation :
public LogApproximationBase<math::LogOp> {
755 using LogApproximationBase::LogApproximationBase;
757 LogicalResult matchAndRewrite(math::LogOp op,
758 PatternRewriter &rewriter)
const final {
759 return logMatchAndRewrite(op, rewriter,
false);
// math::Log2Op pattern: forwards to logMatchAndRewrite with the trailing
// boolean argument set to true.
// NOTE(review): the name of that parameter is not visible in this view.
765struct Log2Approximation :
public LogApproximationBase<math::Log2Op> {
766 using LogApproximationBase::LogApproximationBase;
768 LogicalResult matchAndRewrite(math::Log2Op op,
769 PatternRewriter &rewriter)
const final {
770 return logMatchAndRewrite(op, rewriter,
true);
// Fragment of the math::Log1pOp rewrite pattern, built on math::LogOp of
// u = x + 1.
// NOTE(review): the declarations of uSmall and uInf and the first operand of
// the logLarge multiply are on lines not visible here.
780struct Log1pApproximation :
public OpRewritePattern<math::Log1pOp> {
784 LogicalResult matchAndRewrite(math::Log1pOp op,
785 PatternRewriter &rewriter)
const final;
791Log1pApproximation::matchAndRewrite(math::Log1pOp op,
792 PatternRewriter &rewriter)
const {
796 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
798 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
799 auto bcast = [&](Value value) -> Value {
809 Value cstOne = bcast(
f32Cst(builder, 1.0f));
810 Value x = op.getOperand();
811 Value u = arith::AddFOp::create(builder, x, cstOne);
// Comparison u == 1 (presumably the uSmall flag used below — confirm).
813 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, cstOne);
814 Value logU = math::LogOp::create(builder, u);
// Comparison u == log(u) (presumably the uInf flag used below — confirm).
816 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, u, logU);
// logLarge multiplies a value not visible here by log(u) / (u - 1).
817 Value logLarge = arith::MulFOp::create(
819 arith::DivFOp::create(builder, logU,
820 arith::SubFOp::create(builder, u, cstOne)));
// If either flag is set the result is x itself; otherwise logLarge.
821 Value approximation = arith::SelectOp::create(
822 builder, arith::OrIOp::create(builder, uSmall, uInf), x, logLarge);
// Fragment of the math::AsinOp rewrite pattern (f32 and f16 element types
// only).
// NOTE(review): definitions of elementType, q, s and several intermediate
// steps are on lines not visible in this view.
835struct AsinPolynomialApproximation :
public OpRewritePattern<math::AsinOp> {
839 LogicalResult matchAndRewrite(math::AsinOp op,
840 PatternRewriter &rewriter)
const final;
844AsinPolynomialApproximation::matchAndRewrite(math::AsinOp op,
845 PatternRewriter &rewriter)
const {
846 Value operand = op.getOperand();
849 if (!(elementType.
isF32() || elementType.
isF16()))
851 "only f32 and f16 type is supported.");
852 std::optional<VectorShape> shape =
vectorShape(operand);
854 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
855 auto bcast = [&](Value value) -> Value {
// Small local wrappers around the op builders, to keep the math readable.
859 auto fma = [&](Value a, Value
b, Value c) -> Value {
860 return math::FmaOp::create(builder, a,
b, c);
863 auto mul = [&](Value a, Value
b) -> Value {
864 return arith::MulFOp::create(builder, a,
b);
867 auto sub = [&](Value a, Value
b) -> Value {
868 return arith::SubFOp::create(builder, a,
b);
871 auto abs = [&](Value a) -> Value {
return math::AbsFOp::create(builder, a); };
873 auto sqrt = [&](Value a) -> Value {
874 return math::SqrtOp::create(builder, a);
877 auto scopy = [&](Value a, Value
b) -> Value {
878 return math::CopySignOp::create(builder, a,
b);
881 auto sel = [&](Value a, Value
b, Value c) -> Value {
882 return arith::SelectOp::create(builder, a,
b, c);
// abso = |operand|, aa = operand^2, opp = sqrt(1 - operand^2).
885 Value abso =
abs(operand);
886 Value aa =
mul(operand, operand);
887 Value opp = sqrt(sub(bcast(
floatCst(builder, 1.0, elementType)), aa));
// When operand^2 > 0.5 the polynomial is evaluated on opp instead of
// |operand|.
889 Value gt = arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, aa,
890 bcast(
floatCst(builder, 0.5, elementType)));
892 Value x = sel(gt, opp, abso);
// Two interleaved FMA chains in q, accumulating into r and t.
897 Value r = bcast(
floatCst(builder, 5.5579749017470502e-2, elementType));
898 Value t = bcast(
floatCst(builder, -6.2027913464120114e-2, elementType));
900 r = fma(r, q, bcast(
floatCst(builder, 5.4224464349245036e-2, elementType)));
901 t = fma(t, q, bcast(
floatCst(builder, -1.1326992890324464e-2, elementType)));
902 r = fma(r, q, bcast(
floatCst(builder, 1.5268872539397656e-2, elementType)));
903 t = fma(t, q, bcast(
floatCst(builder, 1.0493798473372081e-2, elementType)));
904 r = fma(r, q, bcast(
floatCst(builder, 1.4106045900607047e-2, elementType)));
905 t = fma(t, q, bcast(
floatCst(builder, 1.7339776384962050e-2, elementType)));
906 r = fma(r, q, bcast(
floatCst(builder, 2.2372961589651054e-2, elementType)));
907 t = fma(t, q, bcast(
floatCst(builder, 3.0381912707941005e-2, elementType)));
908 r = fma(r, q, bcast(
floatCst(builder, 4.4642857881094775e-2, elementType)));
909 t = fma(t, q, bcast(
floatCst(builder, 7.4999999991367292e-2, elementType)));
911 r = fma(r, s, bcast(
floatCst(builder, 1.6666666666670193e-1, elementType)));
// For the gt case use 1.57079632679 - r, then copy the operand's sign onto r.
915 Value rsub = sub(bcast(
floatCst(builder, 1.57079632679, elementType)), r);
916 r = sel(gt, rsub, r);
917 r = scopy(r, operand);
// Fragment of the math::AcosOp rewrite pattern (f32 and f16 element types
// only); the visible lines express the result via math::AsinOp.
// NOTE(review): declarations of elementType, selR, firstPred, trueVal,
// lessThanZero and the final replacement are not visible in this view.
931struct AcosPolynomialApproximation :
public OpRewritePattern<math::AcosOp> {
935 LogicalResult matchAndRewrite(math::AcosOp op,
936 PatternRewriter &rewriter)
const final;
940AcosPolynomialApproximation::matchAndRewrite(math::AcosOp op,
941 PatternRewriter &rewriter)
const {
942 Value operand = op.getOperand();
945 if (!(elementType.
isF32() || elementType.
isF16()))
947 "only f32 and f16 type is supported.");
948 std::optional<VectorShape> shape =
vectorShape(operand);
950 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
951 auto bcast = [&](Value value) -> Value {
955 auto fma = [&](Value a, Value
b, Value c) -> Value {
956 return math::FmaOp::create(builder, a,
b, c);
959 auto mul = [&](Value a, Value
b) -> Value {
960 return arith::MulFOp::create(builder, a,
b);
963 Value negOperand = arith::NegFOp::create(builder, operand);
964 Value zero = bcast(
floatCst(builder, 0.0, elementType));
965 Value half = bcast(
floatCst(builder, 0.5, elementType));
966 Value negOne = bcast(
floatCst(builder, -1.0, elementType));
// r is the operand with positive inputs negated (operand > 0 selects
// negOperand).
968 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, operand, zero);
969 Value r = arith::SelectOp::create(builder, selR, negOperand, operand);
// Comparison r > -0.5625 (presumably firstPred used below — confirm).
970 Value chkConst = bcast(
floatCst(builder, -0.5625, elementType));
972 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, r, chkConst);
// One candidate: fma(9.3282184640716537e-1, 1.6839188885261840e+0, asin(r)).
975 fma(bcast(
floatCst(builder, 9.3282184640716537e-1, elementType)),
976 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
977 math::AsinOp::create(builder, r));
// The other candidate: 2 * asin(sqrt(0.5 * r + 0.5)).
979 Value falseVal = math::SqrtOp::create(builder, fma(half, r, half));
980 falseVal = math::AsinOp::create(builder, falseVal);
981 falseVal =
mul(bcast(
floatCst(builder, 2.0, elementType)), falseVal);
983 r = arith::SelectOp::create(builder, firstPred, trueVal, falseVal);
// Operands satisfying operand >= -1 and operand < 0 get the alternate value
// fma(1.8656436928143307e+0, 1.6839188885261840e+0, -r).
986 Value greaterThanNegOne = arith::CmpFOp::create(
987 builder, arith::CmpFPredicate::OGE, operand, negOne);
990 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
992 Value betweenNegOneZero =
993 arith::AndIOp::create(builder, greaterThanNegOne, lessThanZero);
995 trueVal = fma(bcast(
floatCst(builder, 1.8656436928143307e+0, elementType)),
996 bcast(
floatCst(builder, 1.6839188885261840e+0, elementType)),
997 arith::NegFOp::create(builder, r));
1000 arith::SelectOp::create(builder, betweenNegOneZero, trueVal, r);
// Fragment of a rewrite that evaluates offset + P(x) / Q(x) with coefficients
// chosen per interval of |operand| (f32 and f16 element types only).
// NOTE(review): the enclosing signature, the definitions of elementType,
// bcast, zero, one, and the bodies that fill p / q / offset are not visible
// here; presumably this is the erf approximation — confirm.
1020 Value operand = op.getOperand();
1023 if (!(elementType.
isF32() || elementType.
isF16()))
1025 "only f32 and f16 type is supported.");
// Three intervals, each with degree-4 polynomials (5 coefficients).
1033 const int intervalsCount = 3;
1034 const int polyDegree = 4;
// pp[interval][k]: coefficient table (presumably feeding p below — confirm).
1038 Value pp[intervalsCount][polyDegree + 1];
1039 pp[0][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1040 pp[0][1] = bcast(
floatCst(builder, +1.12837916222975858e+00f, elementType));
1041 pp[0][2] = bcast(
floatCst(builder, -5.23018562988006470e-01f, elementType));
1042 pp[0][3] = bcast(
floatCst(builder, +2.09741709609267072e-01f, elementType));
1043 pp[0][4] = bcast(
floatCst(builder, +2.58146801602987875e-02f, elementType));
1044 pp[1][0] = bcast(
floatCst(builder, +0.00000000000000000e+00f, elementType));
1045 pp[1][1] = bcast(
floatCst(builder, +1.12750687816789140e+00f, elementType));
1046 pp[1][2] = bcast(
floatCst(builder, -3.64721408487825775e-01f, elementType));
1047 pp[1][3] = bcast(
floatCst(builder, +1.18407396425136952e-01f, elementType));
1048 pp[1][4] = bcast(
floatCst(builder, +3.70645533056476558e-02f, elementType));
1049 pp[2][0] = bcast(
floatCst(builder, -3.30093071049483172e-03f, elementType));
1050 pp[2][1] = bcast(
floatCst(builder, +3.51961938357697011e-03f, elementType));
1051 pp[2][2] = bcast(
floatCst(builder, -1.41373622814988039e-03f, elementType));
1052 pp[2][3] = bcast(
floatCst(builder, +2.53447094961941348e-04f, elementType));
1053 pp[2][4] = bcast(
floatCst(builder, -1.71048029455037401e-05f, elementType));
// qq[interval][k]: coefficient table (presumably feeding q below — confirm).
1055 Value qq[intervalsCount][polyDegree + 1];
1056 qq[0][0] = bcast(
floatCst(builder, +1.000000000000000000e+00f, elementType));
1057 qq[0][1] = bcast(
floatCst(builder, -4.635138185962547255e-01f, elementType));
1058 qq[0][2] = bcast(
floatCst(builder, +5.192301327279782447e-01f, elementType));
1059 qq[0][3] = bcast(
floatCst(builder, -1.318089722204810087e-01f, elementType));
1060 qq[0][4] = bcast(
floatCst(builder, +7.397964654672315005e-02f, elementType));
1061 qq[1][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1062 qq[1][1] = bcast(
floatCst(builder, -3.27607011824493086e-01f, elementType));
1063 qq[1][2] = bcast(
floatCst(builder, +4.48369090658821977e-01f, elementType));
1064 qq[1][3] = bcast(
floatCst(builder, -8.83462621207857930e-02f, elementType));
1065 qq[1][4] = bcast(
floatCst(builder, +5.72442770283176093e-02f, elementType));
1066 qq[2][0] = bcast(
floatCst(builder, +1.00000000000000000e+00f, elementType));
1067 qq[2][1] = bcast(
floatCst(builder, -2.06069165953913769e+00f, elementType));
1068 qq[2][2] = bcast(
floatCst(builder, +1.62705939945477759e+00f, elementType));
1069 qq[2][3] = bcast(
floatCst(builder, -5.83389859211130017e-01f, elementType));
1070 qq[2][4] = bcast(
floatCst(builder, +8.21908939856640930e-02f, elementType));
// Per-interval additive offsets.
1072 Value offsets[intervalsCount];
1073 offsets[0] = bcast(
floatCst(builder, 0.0f, elementType));
1074 offsets[1] = bcast(
floatCst(builder, 0.0f, elementType));
1075 offsets[2] = bcast(
floatCst(builder, 1.0f, elementType));
// Per-interval upper bounds compared against x.
1077 Value bounds[intervalsCount];
1078 bounds[0] = bcast(
floatCst(builder, 0.8f, elementType));
1079 bounds[1] = bcast(
floatCst(builder, 2.0f, elementType));
1080 bounds[2] = bcast(
floatCst(builder, 3.75f, elementType));
// x is the operand with negative inputs negated; the sign is restored at the
// end.
1082 Value isNegativeArg =
1083 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, operand, zero);
1084 Value negArg = arith::NegFOp::create(builder, operand);
1085 Value x = arith::SelectOp::create(builder, isNegativeArg, negArg, operand);
// Start from interval 0's offset; p and q hold the selected coefficients.
1087 Value offset = offsets[0];
1088 Value p[polyDegree + 1];
1089 Value q[polyDegree + 1];
1090 for (
int i = 0; i <= polyDegree; ++i) {
// For each interval except the last: where x is below the bound keep the
// current p / q / offset, otherwise select a replacement (the replacement
// operands are on lines not visible here).
1096 Value isLessThanBound[intervalsCount];
1097 for (
int j = 0;
j < intervalsCount - 1; ++
j) {
1098 isLessThanBound[
j] =
1099 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, bounds[
j]);
1100 for (
int i = 0; i <= polyDegree; ++i) {
1101 p[i] = arith::SelectOp::create(builder, isLessThanBound[
j], p[i],
1103 q[i] = arith::SelectOp::create(builder, isLessThanBound[
j], q[i],
1106 offset = arith::SelectOp::create(builder, isLessThanBound[
j], offset,
// The last bound uses ULT (unordered), unlike the OLT comparisons above.
1109 isLessThanBound[intervalsCount - 1] = arith::CmpFOp::create(
1110 builder, arith::CmpFPredicate::ULT, x, bounds[intervalsCount - 1]);
// formula = offset + P(x) / Q(x); where the last-bound test fails use `one`.
1112 Value pPoly = makePolynomialCalculation(builder, p, x);
1113 Value qPoly = makePolynomialCalculation(builder, q, x);
1114 Value rationalPoly = arith::DivFOp::create(builder, pPoly, qPoly);
1115 Value formula = arith::AddFOp::create(builder, offset, rationalPoly);
1116 formula = arith::SelectOp::create(
1117 builder, isLessThanBound[intervalsCount - 1], formula, one);
// Negate the result for negative operands.
1120 Value negFormula = arith::NegFOp::create(builder, formula);
1122 arith::SelectOp::create(builder, isNegativeArg, negFormula, formula);
// Fragment of a rewrite working on a = |x| with a long FMA chain, an
// exp(-a*a) factor, and final fix-ups for infinite, large and negative inputs.
// NOTE(review): the enclosing signature and the definitions of et, bcast, one,
// zero, pos2, neg4, onehalf, posInf, clampVal, c1..c9, e and isNegative are
// not visible here; presumably this is the erfc approximation — confirm.
1141 Value x = op.getOperand();
1163 Value a = math::AbsFOp::create(builder, x);
// r = 1 / (a + 2); q = neg4 * r + one.
1164 Value p = arith::AddFOp::create(builder, a, pos2);
1165 Value r = arith::DivFOp::create(builder, one, p);
1166 Value q = math::FmaOp::create(builder, neg4, r, one);
1167 Value t = math::FmaOp::create(builder, arith::AddFOp::create(builder, q, one),
1170 math::FmaOp::create(builder, arith::NegFOp::create(builder, a), q, t);
1171 q = math::FmaOp::create(builder, r, e, q);
// Polynomial in q: chained FMAs starting from -0x1.a4a000p-12f, folding in
// c1..c9.
1173 p = bcast(
floatCst(builder, -0x1.a4a000p-12f, et));
1175 p = math::FmaOp::create(builder, p, q, c1);
1177 p = math::FmaOp::create(builder, p, q, c2);
1179 p = math::FmaOp::create(builder, p, q, c3);
1181 p = math::FmaOp::create(builder, p, q, c4);
1183 p = math::FmaOp::create(builder, p, q, c5);
1185 p = math::FmaOp::create(builder, p, q, c6);
1187 p = math::FmaOp::create(builder, p, q, c7);
1189 p = math::FmaOp::create(builder, p, q, c8);
1191 p = math::FmaOp::create(builder, p, q, c9);
// d = 2a + 1, r = 1 / d, q = p * r + r, followed by a correction step that
// updates r from e, r and q.
1193 Value d = math::FmaOp::create(builder, pos2, a, one);
1194 r = arith::DivFOp::create(builder, one, d);
1195 q = math::FmaOp::create(builder, p, r, r);
1196 Value negfa = arith::NegFOp::create(builder, a);
1197 Value fmaqah = math::FmaOp::create(builder, q, negfa, onehalf);
1198 Value psubq = arith::SubFOp::create(builder, p, q);
1199 e = math::FmaOp::create(builder, fmaqah, pos2, psubq);
1200 r = math::FmaOp::create(builder, e, r, q);
// s = a*a, e = exp(-s), t = fma(-a, a, s).
1202 Value s = arith::MulFOp::create(builder, a, a);
1203 e = math::ExpOp::create(builder, arith::NegFOp::create(builder, s));
1205 t = math::FmaOp::create(builder, arith::NegFOp::create(builder, a), a, s);
1206 r = math::FmaOp::create(
1208 arith::MulFOp::create(builder, arith::MulFOp::create(builder, r, e), t));
// isNotLessThanInf is an XOR involving the ordered comparison a < posInf (the
// other XOR operand is not visible); where it is set the result is x + x.
1210 Value isNotLessThanInf = arith::XOrIOp::create(
1212 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, a, posInf),
1214 r = arith::SelectOp::create(builder, isNotLessThanInf,
1215 arith::AddFOp::create(builder, x, x), r);
// Beyond clampVal the result is zero.
1216 Value isGreaterThanClamp =
1217 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, a, clampVal);
1218 r = arith::SelectOp::create(builder, isGreaterThanClamp, zero, r);
// For negative x the result is pos2 - r.
1221 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, x, zero);
1222 r = arith::SelectOp::create(builder, isNegative,
1223 arith::SubFOp::create(builder, pos2, r), r);
// Fragment of clampWithNormals: bounds `value` with select-on-compare using
// the UGE and ULE (unordered) predicates. Both bounds must be non-NaN.
// NOTE(review): the third select operand inside selectCmp and the definition
// of bcast are on lines not visible in this view.
1235 const std::optional<VectorShape>
shape,
Value value,
1236 float lowerBound,
float upperBound) {
1237 assert(!std::isnan(lowerBound));
1238 assert(!std::isnan(upperBound));
// selectCmp(pred, value, bound): a select on cmpf(pred, value, bound) whose
// true-branch is `value`.
1244 auto selectCmp = [&builder](
auto pred,
Value value,
Value bound) {
1245 return arith::SelectOp::create(
1246 builder, arith::CmpFOp::create(builder, pred, value, bound), value,
// Lower bound first (UGE), then upper bound (ULE).
1253 value = selectCmp(arith::CmpFPredicate::UGE, value,
1254 bcast(
f32Cst(builder, lowerBound)));
1255 value = selectCmp(arith::CmpFPredicate::ULE, value,
1256 bcast(
f32Cst(builder, upperBound)));
// Fragment of the math::ExpOp rewrite pattern (f32 only in the visible check):
// n = floor(x * log2(e) + 0.5), a reduced x, a polynomial in the reduced x,
// and a final multiply by 2^n.
// NOTE(review): the definitions of elementTy and i32Vec, the failure return,
// and the final replacement are not visible here.
1260struct ExpApproximation :
public OpRewritePattern<math::ExpOp> {
1264 LogicalResult matchAndRewrite(math::ExpOp op,
1265 PatternRewriter &rewriter)
const final;
1269ExpApproximation::matchAndRewrite(math::ExpOp op,
1270 PatternRewriter &rewriter)
const {
1271 auto shape =
vectorShape(op.getOperand().getType());
1273 if (!elementTy.isF32())
1276 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// Small local wrappers around the op builders.
1278 auto add = [&](Value a, Value
b) -> Value {
1279 return arith::AddFOp::create(builder, a,
b);
1281 auto bcast = [&](Value value) -> Value {
1282 return broadcast(builder, value, shape);
1284 auto floor = [&](Value a) {
return math::FloorOp::create(builder, a); };
1285 auto fmla = [&](Value a, Value
b, Value c) {
1286 return math::FmaOp::create(builder, a,
b, c);
1288 auto mul = [&](Value a, Value
b) -> Value {
1289 return arith::MulFOp::create(builder, a,
b);
1313 Value cstHalf = bcast(
f32Cst(builder, 0.5f));
1314 Value cstOne = bcast(
f32Cst(builder, 1.0f));
1317 Value cstLog2ef = bcast(
f32Cst(builder, 1.44269504088896341f));
// cstExpC1 / cstExpC2 are multiplied by n and added to x in the reduction
// below; cstExpP0..cstExpP5 are the polynomial coefficients.
1319 Value cstExpC1 = bcast(
f32Cst(builder, -0.693359375f));
1320 Value cstExpC2 = bcast(
f32Cst(builder, 2.12194440e-4f));
1321 Value cstExpP0 = bcast(
f32Cst(builder, 1.9875691500E-4f));
1322 Value cstExpP1 = bcast(
f32Cst(builder, 1.3981999507E-3f));
1323 Value cstExpP2 = bcast(
f32Cst(builder, 8.3334519073E-3f));
1324 Value cstExpP3 = bcast(
f32Cst(builder, 4.1665795894E-2f));
1325 Value cstExpP4 = bcast(
f32Cst(builder, 1.6666665459E-1f));
1326 Value cstExpP5 = bcast(
f32Cst(builder, 5.0000001201E-1f));
// Clamp the input to [-87.8, 88.8], then n = floor(x * log2(e) + 0.5).
1333 Value x = op.getOperand();
1334 x = clampWithNormals(builder, shape, x, -87.8f, 88.8f);
1335 Value n =
floor(fmla(x, cstLog2ef, cstHalf));
// Clamp n to [-127, 127].
1376 n = clampWithNormals(builder, shape, n, -127.0f, 127.0f);
// Reduce x: x += cstExpC1 * n, then x += cstExpC2 * n.
1379 x = fmla(cstExpC1, n, x);
1380 x = fmla(cstExpC2, n, x);
// Polynomial z in the reduced x, finished as z * x^2 + x.
1383 Value z = fmla(x, cstExpP0, cstExpP1);
1384 z = fmla(z, x, cstExpP2);
1385 z = fmla(z, x, cstExpP3);
1386 z = fmla(z, x, cstExpP4);
1387 z = fmla(z, x, cstExpP5);
1388 z = fmla(z,
mul(x, x), x);
// Convert n to a signed integer and form 2^n via exp2I32.
1393 Value nI32 = arith::FPToSIOp::create(builder, i32Vec, n);
1396 Value pow2 =
exp2I32(builder, nI32);
// ret = z * 2^n.
1399 Value ret =
mul(z, pow2);
1402 return mlir::success();
// Rewrite pattern rooted at math::ExpM1Op. Only the matchAndRewrite
// declaration is visible here; the definition is provided out of line.
1413struct ExpM1Approximation :
public OpRewritePattern<math::ExpM1Op> {
1417 LogicalResult matchAndRewrite(math::ExpM1Op op,
1418 PatternRewriter &rewriter)
const final;
// Builds an approximation of expm1(x) from u = exp(x) using the expression
// (u - 1) * (x / log(u)), followed by select-based special cases.
// NOTE(review): some lines of this function (assignment targets, the final
// replacement/return) are not part of this listing.
1423ExpM1Approximation::matchAndRewrite(math::ExpM1Op op,
1424 PatternRewriter &rewriter)
const {
1428 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
1430 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// Helper: forwards `value` to broadcast(builder, value, shape).
1431 auto bcast = [&](Value value) -> Value {
1432 return broadcast(builder, value, shape);
1438 Value cstOne = bcast(
f32Cst(builder, 1.0f));
1439 Value cstNegOne = bcast(
f32Cst(builder, -1.0f));
1440 Value x = op.getOperand();
1441 Value u = math::ExpOp::create(builder, x);
// UEQ compare of u against 1.0. The variable receiving this result is on a
// line not shown -- presumably `uEqOneOrNaN`, which is used further down.
1443 arith::CmpFOp::create(builder, arith::CmpFPredicate::UEQ, u, cstOne);
1444 Value uMinusOne = arith::SubFOp::create(builder, u, cstOne);
// Mask for lanes where u - 1 compares (ordered) equal to -1.
1445 Value uMinusOneEqNegOne = arith::CmpFOp::create(
1446 builder, arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne);
1448 Value logU = math::LogOp::create(builder, u);
// OEQ compare of log(u) against u. The variable receiving this result is on
// a line not shown -- presumably `isInf`, which is used below.
1452 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, logU, u);
// Main expression: (u - 1) * (x / log(u)).
1455 Value expm1 = arith::MulFOp::create(builder, uMinusOne,
1456 arith::DivFOp::create(builder, x, logU));
// Where isInf holds, use u itself instead of the expression above.
1457 expm1 = arith::SelectOp::create(builder, isInf, u, expm1);
// Final selection, in priority order: x where uEqOneOrNaN holds, otherwise
// -1 where u - 1 == -1, otherwise the expm1 value computed above.
1458 Value approximation = arith::SelectOp::create(
1459 builder, uEqOneOrNaN, x,
1460 arith::SelectOp::create(builder, uMinusOneEqNegOne, cstNegOne, expm1));
// Rewrite pattern template parameterized on a compile-time `isSine` flag and
// the op type `OpTy`. It inherits the OpRewritePattern<OpTy> constructors and
// declares matchAndRewrite, which is defined out of line.
1471template <
bool isSine,
typename OpTy>
1472struct SinAndCosApproximation :
public OpRewritePattern<OpTy> {
1474 using OpRewritePattern<OpTy>::OpRewritePattern;
1476 LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter)
const final;
// TWO_OVER_PI expands to a long double literal numerically equal to 2/pi.
// NOTE(review): the second literal below is numerically pi/2, but the
// #define line that names it is not part of this listing.
1480#define TWO_OVER_PI \
1481 0.6366197723675813430755350534900574481378385829618257949906693762L
1483 1.5707963267948966192313216916397514420985846996875529104874722961L
// Out-of-line matchAndRewrite shared by the sine and cosine patterns; the
// `isSine` template flag decides which selection masks are used below.
// NOTE(review): many lines of this function are not part of this listing,
// including most lambda headers and the definitions of bcast, twoOverPi,
// piOverTwo, cstOne, cstCC2, y2 and v6. Comments below describe only what
// the visible lines show.
1488template <
bool isSine,
typename OpTy>
1489LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
// Type-trait condition and message restricting OpTy to math::SinOp or
// math::CosOp (the statement these belong to starts on a line not shown).
1492 llvm::is_one_of<OpTy, math::SinOp, math::CosOp>::value,
1493 "SinAndCosApproximation pattern expects math::SinOp or math::CosOp");
// Bodies of small builder helpers; their headers are not shown. Judging by
// later uses these are presumably `mul` and `sub`.
1505 return arith::MulFOp::create(builder, a,
b);
1508 return arith::SubFOp::create(builder, a,
b);
1510 auto floor = [&](
Value a) {
return math::FloorOp::create(builder, a); };
// Converts a float value to a signed i32 value via arith::FPToSIOp.
// NOTE(review): "Singed" looks like a typo for "Signed"; the identifier is
// left unchanged because it is referenced later in this function.
1513 auto fPToSingedInteger = [&](
Value a) ->
Value {
1514 return arith::FPToSIOp::create(builder, i32Vec, a);
// Bitwise AND with 3 -- presumably the body of `modulo4` used below.
1518 return arith::AndIOp::create(builder, a, bcast(
i32Cst(builder, 3)));
// Integer equality and signed greater-than compares -- presumably the bodies
// of `isEqualTo` and `isGreaterThan`.
1522 return arith::CmpIOp::create(builder, arith::CmpIPredicate::eq, a,
b);
1526 return arith::CmpIOp::create(builder, arith::CmpIPredicate::sgt, a,
b);
1530 return arith::SelectOp::create(builder, cond, t, f);
1534 return math::FmaOp::create(builder, a,
b, c);
1538 return arith::OrIOp::create(builder, a,
b);
// Range reduction: k = floor(x * twoOverPi), y = x - k * piOverTwo.
1544 Value x = op.getOperand();
1546 Value k = floor(
mul(x, twoOverPi));
1548 Value y = sub(x,
mul(k, piOverTwo));
// Broadcast f32 constants. The cstSC* and cstCC* values are the two
// coefficient sets chosen between further down. Three of the bcast(...)
// expressions below have their assignment targets on lines not shown.
1551 Value cstNegativeOne = bcast(
f32Cst(builder, -1.0));
1553 Value cstSC2 = bcast(
f32Cst(builder, -0.16666667163372039794921875f));
1554 Value cstSC4 = bcast(
f32Cst(builder, 8.333347737789154052734375e-3f));
1555 Value cstSC6 = bcast(
f32Cst(builder, -1.9842604524455964565277099609375e-4f));
1557 bcast(
f32Cst(builder, 2.760012648650445044040679931640625e-6f));
1559 bcast(
f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
1562 Value cstCC4 = bcast(
f32Cst(builder, 4.166664183139801025390625e-2f));
1563 Value cstCC6 = bcast(
f32Cst(builder, -1.388833043165504932403564453125e-3f));
1564 Value cstCC8 = bcast(
f32Cst(builder, 2.47562347794882953166961669921875e-5f));
1566 bcast(
f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
// Integer form of k passed through modulo4, then compared against 0..3 to
// form the kR0..kR3 masks.
1568 Value kMod4 = modulo4(fPToSingedInteger(k));
1570 Value kR0 = isEqualTo(kMod4, bcast(
i32Cst(builder, 0)));
1571 Value kR1 = isEqualTo(kMod4, bcast(
i32Cst(builder, 1)));
1572 Value kR2 = isEqualTo(kMod4, bcast(
i32Cst(builder, 2)));
1573 Value kR3 = isEqualTo(kMod4, bcast(
i32Cst(builder, 3)));
// sinuseCos: for sine, kMod4 is 1 or 3; for cosine, kMod4 is 0 or 2.
// negativeRange: for sine, kMod4 > 1; for cosine, kMod4 is 1 or 2.
1575 Value sinuseCos = isSine ? bitwiseOr(kR1, kR3) : bitwiseOr(kR0, kR2);
1576 Value negativeRange = isSine ? isGreaterThan(kMod4, bcast(
i32Cst(builder, 1)))
1577 : bitwiseOr(kR1, kR2);
// Pick the base factor (cstOne vs. y) and the per-degree coefficients:
// the cstCC* set where sinuseCos holds, the cstSC* set otherwise.
1581 Value base = select(sinuseCos, cstOne, y);
1582 Value cstC2 = select(sinuseCos, cstCC2, cstSC2);
1583 Value cstC4 = select(sinuseCos, cstCC4, cstSC4);
1584 Value cstC6 = select(sinuseCos, cstCC6, cstSC6);
1585 Value cstC8 = select(sinuseCos, cstCC8, cstSC8);
1586 Value cstC10 = select(sinuseCos, cstCC10, cstSC10);
// Nested fmla chain in y2, from the degree-10 coefficient down to the
// constant term cstOne.
1588 Value v1 = fmla(y2, cstC10, cstC8);
1589 Value v2 = fmla(y2, v1, cstC6);
1590 Value v3 = fmla(y2, v2, cstC4);
1591 Value v4 = fmla(y2, v3, cstC2);
1592 Value v5 = fmla(y2, v4, cstOne);
// Negate v6 (multiply by -1) where negativeRange holds.
1595 Value approximation = select(negativeRange,
mul(cstNegativeOne, v6), v6);
// Rewrite pattern rooted at math::CbrtOp. Only the matchAndRewrite
// declaration is visible here; the definition is provided out of line.
1607struct CbrtApproximation :
public OpRewritePattern<math::CbrtOp> {
1610 LogicalResult matchAndRewrite(math::CbrtOp op,
1611 PatternRewriter &rewriter)
const final;
// Builds a cube-root approximation of the operand: an integer bit
// manipulation of |x| gives an initial estimate, two Newton iterations refine
// it, zero input is handled with a select, and the operand's sign is copied
// onto the result.
// NOTE(review): some lines (e.g. the definitions of intTy/floatTy, the rest
// of the bconst lambda, the final replacement/return) are not shown here.
1618CbrtApproximation::matchAndRewrite(math::CbrtOp op,
1619 PatternRewriter &rewriter)
const {
1620 auto operand = op.getOperand();
1624 ImplicitLocOpBuilder
b(op->getLoc(), rewriter);
1625 std::optional<VectorShape> shape =
vectorShape(operand);
// Helper that materializes `attr` as an arith constant; the remainder of the
// lambda is not shown (presumably it broadcasts to `shape` -- confirm).
1634 auto bconst = [&](TypedAttr attr) -> Value {
1635 Value value = arith::ConstantOp::create(
b, attr);
1640 Value intTwo = bconst(
b.getI32IntegerAttr(2));
1641 Value intFour = bconst(
b.getI32IntegerAttr(4));
1642 Value intEight = bconst(
b.getI32IntegerAttr(8));
1643 Value intMagic = bconst(
b.getI32IntegerAttr(0x2a5137a0));
1644 Value fpThird = bconst(
b.getF32FloatAttr(0.33333333f));
1645 Value fpTwo = bconst(
b.getF32FloatAttr(2.0f));
1646 Value fpZero = bconst(
b.getF32FloatAttr(0.0f));
// Reinterpret |x| as an integer and scale it by roughly 1/3 using shifts and
// adds: (1/4 + 1/16) * (1 + 1/16) * (1 + 1/256) ~= 0.3333.
1652 Value absValue = math::AbsFOp::create(
b, operand);
1653 Value intValue = arith::BitcastOp::create(
b, intTy, absValue);
1654 Value divideBy4 = arith::ShRSIOp::create(
b, intValue, intTwo);
1655 Value divideBy16 = arith::ShRSIOp::create(
b, intValue, intFour);
1656 intValue = arith::AddIOp::create(
b, divideBy4, divideBy16);
// intValue += intValue >> 4
1659 divideBy16 = arith::ShRSIOp::create(
b, intValue, intFour);
1660 intValue = arith::AddIOp::create(
b, intValue, divideBy16);
// intValue += intValue >> 8
1663 Value divideBy256 = arith::ShRSIOp::create(
b, intValue, intEight);
1664 intValue = arith::AddIOp::create(
b, intValue, divideBy256);
// Add the magic constant 0x2a5137a0 to the scaled bit pattern.
1667 intValue = arith::AddIOp::create(
b, intValue, intMagic);
// Reinterpret as float and run one Newton iteration for f^3 = |x|:
// f <- (2 * f + |x| / f^2) * (1/3).
1671 Value floatValue = arith::BitcastOp::create(
b, floatTy, intValue);
1672 Value squared = arith::MulFOp::create(
b, floatValue, floatValue);
1673 Value mulTwo = arith::MulFOp::create(
b, floatValue, fpTwo);
1674 Value divSquared = arith::DivFOp::create(
b, absValue, squared);
1675 floatValue = arith::AddFOp::create(
b, mulTwo, divSquared);
1676 floatValue = arith::MulFOp::create(
b, floatValue, fpThird);
// Second Newton iteration, same formula.
1679 squared = arith::MulFOp::create(
b, floatValue, floatValue);
1680 mulTwo = arith::MulFOp::create(
b, floatValue, fpTwo);
1681 divSquared = arith::DivFOp::create(
b, absValue, squared);
1682 floatValue = arith::AddFOp::create(
b, mulTwo, divSquared);
1683 floatValue = arith::MulFOp::create(
b, floatValue, fpThird);
// Ordered-equal compare of |x| against 0.0; the variable receiving it is on
// a line not shown -- presumably `isZero`, used in the select below, which
// substitutes 0.0 for those lanes.
1687 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, absValue, fpZero);
1688 floatValue = arith::SelectOp::create(
b, isZero, fpZero, floatValue);
// Copy the sign of the original operand onto the result.
1689 floatValue = math::CopySignOp::create(
b, floatValue, operand);
// Rewrite pattern rooted at math::RsqrtOp. Only the matchAndRewrite
// declaration is visible here; the definition is provided out of line.
1700struct RsqrtApproximation :
public OpRewritePattern<math::RsqrtOp> {
1703 LogicalResult matchAndRewrite(math::RsqrtOp op,
1704 PatternRewriter &rewriter)
const final;
// Builds a reciprocal-square-root approximation: an x86vector::RsqrtOp
// estimate refined by one Newton-Raphson step, with the raw estimate kept for
// inputs flagged by notNormalFiniteMask.
// NOTE(review): some lines (the statement guarded by the shape check, the
// head of the call producing yApprox, the final replacement/return) are not
// part of this listing.
1709RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1710 PatternRewriter &rewriter)
const {
1714 std::optional<VectorShape> shape =
vectorShape(op.getOperand());
// Condition is true when the operand has no vector shape, the shape is
// empty, or the innermost dimension is not a multiple of 8. The guarded
// statement is not shown (presumably a match failure -- confirm).
1717 if (!shape || shape->sizes.empty() || shape->sizes.back() % 8 != 0)
1720 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// Helper: forwards `value` to broadcast(builder, value, shape).
1721 auto bcast = [&](Value value) -> Value {
1722 return broadcast(builder, value, shape);
// Constants built from raw f32 bit patterns (0x7f800000, 0x00800000) and
// from the float values 1.5 and -0.5.
1725 Value cstPosInf = bcast(
f32FromBits(builder, 0x7f800000u));
1726 Value cstOnePointFive = bcast(
f32Cst(builder, 1.5f));
1727 Value cstNegHalf = bcast(
f32Cst(builder, -0.5f));
1728 Value cstMinNormPos = bcast(
f32FromBits(builder, 0x00800000u));
// negHalf = -0.5 * x
1730 Value negHalf = arith::MulFOp::create(builder, op.getOperand(), cstNegHalf);
// Mask of lanes where x < cstMinNormPos (ordered) or x == cstPosInf.
1734 Value ltMinMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT,
1735 op.getOperand(), cstMinNormPos);
1736 Value infMask = arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
1737 op.getOperand(), cstPosInf);
1738 Value notNormalFiniteMask = arith::OrIOp::create(builder, ltMinMask, infMask);
// Arguments of a call whose head is on a line not shown -- presumably
// handleMultidimensionalVectors(...) assigned to yApprox. The callback
// creates an x86vector::RsqrtOp from the operands it is given; 8 is passed
// in the vector-width position.
1742 builder, op->getOperands(), 8, [&builder](
ValueRange operands) -> Value {
1743 return x86vector::RsqrtOp::create(builder, operands);
// One Newton-Raphson step:
//   yNewton = yApprox * (1.5 + yApprox * (negHalf * yApprox))
//           = yApprox * (1.5 - 0.5 * x * yApprox^2).
1750 Value inner = arith::MulFOp::create(builder, negHalf, yApprox);
1751 Value fma = math::FmaOp::create(builder, yApprox, inner, cstOnePointFive);
1752 Value yNewton = arith::MulFOp::create(builder, yApprox, fma);
// Select yApprox where notNormalFiniteMask holds, yNewton elsewhere (the
// variable receiving this result is on a line not shown).
1760 arith::SelectOp::create(builder, notNormalFiniteMask, yApprox, yNewton);
// Template helper parameterized on OpType; the visible part of its body is
// gated on predicate(OpType::getOperationName()).
// NOTE(review): the function's name, parameters and the guarded statements
// are on lines not part of this listing.
1783template <
typename OpType>
1788 if (predicate(OpType::getOperationName())) {
// Template helper parameterized on an op type and a pattern type; the visible
// part of its body is gated on predicate(OpType::getOperationName()).
// NOTE(review): the function's name, parameters and the guarded statements
// are on lines not part of this listing.
1824template <
typename OpType,
typename PatternType>
1828 if (predicate(OpType::getOperationName())) {
// Fragments of pattern-registration calls and op-name predicates.
// NOTE(review): the enclosing function headers and the heads of these calls
// are on lines not part of this listing.
1837 AcosPolynomialApproximation>(
1840 AsinPolynomialApproximation>(
// CosOp and SinOp both use SinAndCosApproximation, instantiated with
// isSine = false and isSine = true respectively.
1849 CosOp, SinAndCosApproximation<false, math::CosOp>>(
patterns, predicate,
1869 SinOp, SinAndCosApproximation<true, math::SinOp>>(
patterns, predicate,
// Op-name list handed to llvm::is_contained; the value searched for is on a
// line not shown.
1879 return llvm::is_contained(
1880 {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
1881 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1882 math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1883 math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1884 math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1885 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1886 math::CosOp::getOperationName()},
// Predicate lambda over an op name. Compared with the list above, this one
// additionally contains AsinOp and AcosOp.
1891 patterns, [](StringRef name) ->
bool {
1892 return llvm::is_contained(
1893 {math::AtanOp::getOperationName(),
1894 math::Atan2Op::getOperationName(),
1895 math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
1896 math::Log2Op::getOperationName(),
1897 math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1898 math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1899 math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1900 math::ExpM1Op::getOperationName(),
1901 math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1902 math::CosOp::getOperationName()},
// Predicate that accepts only the operation name of math::RsqrtOp.
1907 auto predicateRsqrt = [](StringRef name) {
1908 return name == math::RsqrtOp::getOperationName();
static llvm::ManagedStatic< PassManagerOptions > options
static Value exp2I32(ImplicitLocOpBuilder &builder, Value arg)
static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static Value boolCst(ImplicitLocOpBuilder &builder, bool value)
static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType)
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static std::optional< VectorShape > vectorShape(Type type)
static Value i32Cst(ImplicitLocOpBuilder &builder, int32_t value)
static Type broadcast(Type type, std::optional< VectorShape > shape)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value f32FromBits(ImplicitLocOpBuilder &builder, uint32_t bits)
static Value f32Cst(ImplicitLocOpBuilder &builder, double value)
static void populateMathPolynomialApproximationPattern(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit benefit)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
IntegerAttr getI32IntegerAttr(int32_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
FloatAttr getF32FloatAttr(float value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns)
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
void populateMathF32ExpansionPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
int64_t computeMaxLinearIndex(ArrayRef< int64_t > basis)
Return the number of elements of basis (i.e.
void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns, llvm::function_ref< bool(StringRef)> predicate, PatternBenefit=1)
ArrayRef< int64_t > sizes
ArrayRef< bool > scalableFlags
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
LogicalResult matchAndRewrite(math::ErfOp op, PatternRewriter &rewriter) const final
LogicalResult matchAndRewrite(math::ErfcOp op, PatternRewriter &rewriter) const final
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.