115 auto floatType = op.getOperand().getType();
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
233 Value operand = op.getOperand();
241 Value gtCheck = arith::CmpFOp::create(
b, arith::CmpFPredicate::OGT, operand,
244 arith::SelectOp::create(
b, op->getLoc(), gtCheck, one, zero);
246 Value ret = arith::AddFOp::create(
b, opType, fpFixedConvert, incrValue);
258 Value base = op.getOperand(0);
259 Value power = op.getOperand(1);
262 auto convertFPowItoPowf = [&]() -> LogicalResult {
263 Value castPowerToFp =
264 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
265 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
273 return convertFPowItoPowf();
277 return convertFPowItoPowf();
279 int64_t powerInt = value.getSExtValue();
280 bool isNegative = powerInt < 0;
281 int64_t absPower = std::abs(powerInt);
285 while (absPower > 0) {
287 res = arith::MulFOp::create(
b, baseType, base, res);
289 base = arith::MulFOp::create(
b, baseType, base, base);
295 .getFloatSemantics();
298 APFloat::getZero(sem,
false), rewriter);
301 APFloat::getZero(sem,
true), rewriter);
304 APFloat::getInf(sem,
false), rewriter);
307 APFloat::getInf(sem,
true), rewriter);
309 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, res, zero);
310 Value negZeroEqCheck =
311 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, res, negZero);
312 res = arith::DivFOp::create(
b, baseType, one, res);
314 arith::SelectOp::create(
b, op->getLoc(), zeroEqCheck, posInfinity, res);
315 res = arith::SelectOp::create(
b, op->getLoc(), negZeroEqCheck, negInfinity,
328 Value operandA = op.getOperand(0);
329 Value operandB = op.getOperand(1);
330 auto typeA = operandA.
getType();
331 auto typeB = operandB.
getType();
337 return arith::MulFOp::create(
b, x, y);
340 if (valueB.isZero()) {
346 if (valueB.isExactlyValue(1.0)) {
351 if (valueB.isExactlyValue(-1.0)) {
354 Value div = arith::DivFOp::create(
b, one, operandA);
358 if (valueB.isExactlyValue(0.5)) {
360 Value sqrt = math::SqrtOp::create(
b, operandA);
364 if (valueB.isExactlyValue(-0.5)) {
366 Value rsqrt = math::RsqrtOp::create(
b, operandA);
370 if (valueB.isExactlyValue(2.0)) {
372 rewriter.
replaceOp(op, mulf(operandA, operandA));
375 if (valueB.isExactlyValue(-2.0)) {
379 Value div = arith::DivFOp::create(
b, one, mulf(operandA, operandA));
383 if (valueB.isExactlyValue(3.0)) {
384 rewriter.
replaceOp(op, mulf(mulf(operandA, operandA), operandA));
389 Value logA = math::LogOp::create(
b, operandA);
390 Value mult = arith::MulFOp::create(
b, operandB, logA);
391 Value expResult = math::ExpOp::create(
b, mult);
416 Value operand = op.getOperand();
420 if (!opEType.
isF32()) {
424 Type i32Ty =
b.getI32Type();
425 if (
auto shapedTy = dyn_cast<ShapedType>(opType))
426 i32Ty = shapedTy.clone(i32Ty);
433 Value incrValue = math::CopySignOp::create(
b, half, operand);
434 Value add = arith::AddFOp::create(
b, opType, operand, incrValue);
457 Value operandBitcast = arith::BitcastOp::create(
b, i32Ty, operand);
458 Value operandExp = arith::AndIOp::create(
459 b, arith::ShRUIOp::create(
b, operandBitcast, c23), expMask);
460 Value operandBiasedExp = arith::SubIOp::create(
b, operandExp, c127);
461 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
462 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
464 Value result = arith::SelectOp::create(
b, isSpecialValOrLargeVal, operand,
474 auto operand = op.getOperand();
475 auto operandTy = operand.
getType();
479 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
483 uint64_t allbits = -1;
485 allbits = allbits >> (64 - bitwidth);
490 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
493 auto mask =
createIntConst(loc, operandTy, allbits >> half, rewriter);
495 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
497 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
498 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
500 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
501 count = arith::SelectOp::create(rewriter, loc, pred,
add, count);
505 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
509 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
519 auto operand = op.getOperand();
521 Type resultTy = op.getType();
525 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
529 Type fTy = operandTy;
531 if (
auto shapedTy = dyn_cast<ShapedType>(fTy)) {
532 iTy = shapedTy.clone(iTy);
537 unsigned mantissaWidth =
538 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
539 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
556 Value operandBitcast = arith::BitcastOp::create(
b, iTy, operand);
557 Value round = math::RoundOp::create(
b, operand);
558 Value roundBitcast = arith::BitcastOp::create(
b, iTy, round);
561 Value operandExp = arith::AndIOp::create(
562 b, arith::ShRUIOp::create(
b, operandBitcast, c23), expMask);
563 Value operandBiasedExp = arith::SubIOp::create(
b, operandExp, c127);
564 Value roundExp = arith::AndIOp::create(
565 b, arith::ShRUIOp::create(
b, roundBitcast, c23), expMask);
566 Value roundBiasedExp = arith::SubIOp::create(
b, roundExp, c127);
570 Value clampedShift = arith::MaxSIOp::create(
b, shift, c0);
571 clampedShift = arith::MinSIOp::create(
b, clampedShift, c31);
572 return arith::ShRUIOp::create(
b, x, clampedShift);
575 auto maskMantissa = [&](
Value mantissa,
577 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
578 return arith::AndIOp::create(
b, mantissa, shiftedMantissaMask);
595 Value roundBiasedExpEq0 =
596 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
597 Value roundBiasedExpMinus1 = arith::SubIOp::create(
b, roundBiasedExp, c1);
598 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
599 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
600 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
601 roundIsNotEvenOrSpecialVal =
602 arith::OrIOp::create(
b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
611 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
612 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
613 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
614 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
615 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
616 Value operandIsHalfway =
617 arith::CmpIOp::create(
b, arith::CmpIPredicate::eq, operandMaskedMantissa,
618 expectedOperandMaskedMantissa);
620 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
621 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
622 Value operandBiasedExpLt23 = arith::CmpIOp::create(
623 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
625 arith::AndIOp::create(
b, operandIsHalfway, operandBiasedExpLt23);
627 arith::AndIOp::create(
b, operandIsHalfway, operandBiasedExpGeNeg1);
631 Value sign = math::CopySignOp::create(
b, c1Float, operand);
632 Value roundShifted = arith::SubFOp::create(
b, round, sign);
636 arith::AndIOp::create(
b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
637 Value result = arith::SelectOp::create(
b, needsShift, roundShifted, round);
680 auto filter = [&](StringRef name) {
684 assert(
"math" == MathDialect::getDialectNamespace());
685 name.consume_front(
"math.");
686 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
688 if (filter(CountLeadingZerosOp::getOperationName()))
690 if (filter(SinhOp::getOperationName()))
692 if (filter(CoshOp::getOperationName()))
694 if (filter(TanOp::getOperationName()))
696 if (filter(TanhOp::getOperationName()))
698 if (filter(AsinhOp::getOperationName()))
700 if (filter(AcoshOp::getOperationName()))
702 if (filter(AtanhOp::getOperationName()))
704 if (filter(FmaOp::getOperationName()))
706 if (filter(CeilOp::getOperationName()))
708 if (filter(Exp2Op::getOperationName()))
710 if (filter(PowFOp::getOperationName()))
712 if (filter(FPowIOp::getOperationName()))
714 if (filter(RoundOp::getOperationName()))
716 if (filter(RoundEvenOp::getOperationName()))
718 if (filter(RsqrtOp::getOperationName()))
720 if (filter(ClampFOp::getOperationName()))