20 #include <type_traits>
23 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
24 #include "mlir/Conversion/Passes.h.inc"
// Selector for the magnitude-derived quantity a caller wants. The SqrtOp and
// RsqrtOp lowerings in this file pass AbsFn::sqrt / AbsFn::rsqrt to
// computeAbs, and the code that tests `fn` adds a math::RsqrtOp step or a
// quarter-power math::PowFOp step for those two cases.
// NOTE(review): presumably abs -> |z|, sqrt -> sqrt(|z|), rsqrt -> 1/sqrt(|z|);
// the full computeAbs body is not visible here -- confirm.
31 enum class AbsFn {
abs, sqrt, rsqrt };
46 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
47 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
49 Value ratioSq = b.
create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
50 Value ratioSqPlusOne = b.
create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
53 if (fn == AbsFn::rsqrt) {
54 ratioSqPlusOne = b.
create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
59 if (fn == AbsFn::sqrt) {
64 Value p025 = b.
create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
65 result = b.
create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
67 Value sqrt = b.
create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
68 result = b.
create<arith::MulFOp>(
max, sqrt, fmfWithNaNInf);
71 Value isNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
72 result, fmfWithNaNInf);
73 return b.
create<arith::SelectOp>(isNaN,
min, result);
80 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
84 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
86 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
87 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
88 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
99 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
103 auto type = cast<ComplexType>(op.getType());
104 Type elementType = type.getElementType();
105 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
107 Value lhs = adaptor.getLhs();
108 Value rhs = adaptor.getRhs();
110 Value rhsSquared = b.
create<complex::MulOp>(type, rhs, rhs, fmf);
111 Value lhsSquared = b.
create<complex::MulOp>(type, lhs, lhs, fmf);
112 Value rhsSquaredPlusLhsSquared =
113 b.
create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
114 Value sqrtOfRhsSquaredPlusLhsSquared =
115 b.
create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
119 Value one = b.
create<arith::ConstantOp>(elementType,
121 Value i = b.
create<complex::CreateOp>(type, zero, one);
122 Value iTimesLhs = b.
create<complex::MulOp>(i, lhs, fmf);
123 Value rhsPlusILhs = b.
create<complex::AddOp>(rhs, iTimesLhs, fmf);
126 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
127 Value logResult = b.
create<complex::LogOp>(divResult, fmf);
131 Value negativeI = b.
create<complex::CreateOp>(type, zero, negativeOne);
138 template <
typename ComparisonOp, arith::CmpFPredicate p>
// Op type chosen at compile time from ComparisonOp: arith::AndIOp when
// ComparisonOp is complex::EqualOp, arith::OrIOp for any other op.
// NOTE(review): presumably used to merge the real-part and imaginary-part
// arith::CmpFOp results; the use site is not visible here -- confirm.
141 using ResultCombiner =
142 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
143 arith::AndIOp, arith::OrIOp>;
146 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
148 auto loc = op.getLoc();
149 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
151 Value realLhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getLhs());
152 Value imagLhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getLhs());
153 Value realRhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getRhs());
154 Value imagRhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getRhs());
155 Value realComparison =
156 rewriter.
create<arith::CmpFOp>(loc, p, realLhs, realRhs);
157 Value imagComparison =
158 rewriter.
create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
169 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
174 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
176 auto type = cast<ComplexType>(adaptor.getLhs().getType());
177 auto elementType = cast<FloatType>(type.getElementType());
179 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
181 Value realLhs = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
182 Value realRhs = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
183 Value resultReal = b.
create<BinaryStandardOp>(elementType, realLhs, realRhs,
185 Value imagLhs = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
186 Value imagRhs = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
187 Value resultImag = b.
create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
195 template <
typename TrigonometricOp>
202 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
204 auto loc = op.getLoc();
205 auto type = cast<ComplexType>(adaptor.getComplex().getType());
206 auto elementType = cast<FloatType>(type.getElementType());
207 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
210 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
212 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
218 loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
219 Value exp = rewriter.
create<math::ExpOp>(loc, imag, fmf);
220 Value scaledExp = rewriter.
create<arith::MulFOp>(loc, half, exp, fmf);
221 Value reciprocalExp = rewriter.
create<arith::DivFOp>(loc, half, exp, fmf);
222 Value sin = rewriter.
create<math::SinOp>(loc, real, fmf);
223 Value cos = rewriter.
create<math::CosOp>(loc, real, fmf);
226 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
233 virtual std::pair<Value, Value>
236 arith::FastMathFlagsAttr fmf)
const = 0;
239 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
240 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
245 arith::FastMathFlagsAttr fmf)
const override {
256 rewriter.
create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
257 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, cos, fmf);
259 rewriter.
create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
260 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, sin, fmf);
261 return {resultReal, resultImag};
266 DivOpConversion(
MLIRContext *context, complex::ComplexRangeFlags target)
272 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
274 auto loc = op.getLoc();
275 auto type = cast<ComplexType>(adaptor.getLhs().getType());
276 auto elementType = cast<FloatType>(type.getElementType());
277 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
280 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getLhs());
282 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getLhs());
284 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getRhs());
286 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getRhs());
288 Value resultReal, resultImag;
290 if (complexRange == complex::ComplexRangeFlags::basic ||
291 complexRange == complex::ComplexRangeFlags::none) {
293 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
295 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
297 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
308 complex::ComplexRangeFlags complexRange;
315 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
317 auto loc = op.getLoc();
318 auto type = cast<ComplexType>(adaptor.getComplex().getType());
319 auto elementType = cast<FloatType>(type.getElementType());
320 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
323 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
325 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
326 Value expReal = rewriter.
create<math::ExpOp>(loc, real, fmf.getValue());
327 Value cosImag = rewriter.
create<math::CosOp>(loc, imag, fmf.getValue());
329 rewriter.
create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
330 Value sinImag = rewriter.
create<math::SinOp>(loc, imag, fmf.getValue());
332 rewriter.
create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
342 arith::FastMathFlagsAttr fmf) {
343 auto argType = mlir::cast<FloatType>(arg.
getType());
346 for (
unsigned i = 1; i < coefficients.size(); ++i) {
347 poly = b.
create<math::FmaOp>(
363 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
365 auto type = op.getType();
366 auto elemType = mlir::cast<FloatType>(type.getElementType());
368 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
370 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
371 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
376 Value expm1Real = b.
create<math::ExpM1Op>(real, fmf);
377 Value expReal = b.
create<arith::AddFOp>(expm1Real, one, fmf);
380 Value cosm1Imag = emitCosm1(imag, fmf, b);
381 Value cosImag = b.
create<arith::AddFOp>(cosm1Imag, one, fmf);
384 b.
create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
386 Value imagIsZero = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
387 zero, fmf.getValue());
389 imagIsZero, zero, b.
create<arith::MulFOp>(expReal, sinImag, fmf));
397 Value emitCosm1(
Value arg, arith::FastMathFlagsAttr fmf,
399 auto argType = mlir::cast<FloatType>(arg.
getType());
405 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
406 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
407 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
408 4.1666666666666666609054E-2,
411 Value forLargeArg = b.
create<arith::AddFOp>(cos, negOne, fmf);
413 Value argPow2 = b.
create<arith::MulFOp>(arg, arg, fmf);
414 Value argPow4 = b.
create<arith::MulFOp>(argPow2, argPow2, fmf);
415 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
418 b.
create<arith::AddFOp>(b.
create<arith::MulFOp>(argPow4, poly, fmf),
419 b.
create<arith::MulFOp>(negHalf, argPow2, fmf));
424 Value cond = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
425 piOver4Pow2, fmf.getValue());
426 return b.
create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
434 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
436 auto type = cast<ComplexType>(adaptor.getComplex().getType());
437 auto elementType = cast<FloatType>(type.getElementType());
438 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
441 Value abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(),
443 Value resultReal = b.
create<math::LogOp>(elementType,
abs, fmf.getValue());
444 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
445 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
447 b.
create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
458 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
460 auto type = cast<ComplexType>(adaptor.getComplex().getType());
461 auto elementType = cast<FloatType>(type.getElementType());
462 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
465 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
466 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
468 Value half = b.
create<arith::ConstantOp>(elementType,
470 Value one = b.
create<arith::ConstantOp>(elementType,
472 Value realPlusOne = b.
create<arith::AddFOp>(real, one, fmf);
473 Value absRealPlusOne = b.
create<math::AbsFOp>(realPlusOne, fmf);
476 Value maxAbs = b.
create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
477 Value minAbs = b.
create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
479 Value useReal = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
480 realPlusOne, absImag, fmf);
481 Value maxMinusOne = b.
create<arith::SubFOp>(maxAbs, one, fmf);
482 Value maxAbsOfRealPlusOneAndImagMinusOne =
483 b.
create<arith::SelectOp>(useReal, real, maxMinusOne);
484 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
485 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
486 Value minMaxRatio = b.
create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
487 Value logOfMaxAbsOfRealPlusOneAndImag =
488 b.
create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
490 b.
create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
493 b.
create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
494 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
496 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
498 Value resultImag = b.
create<math::Atan2Op>(imag, realPlusOne, fmf);
509 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
512 auto type = cast<ComplexType>(adaptor.getLhs().getType());
513 auto elementType = cast<FloatType>(type.getElementType());
514 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
515 auto fmfValue = fmf.getValue();
516 Value lhsReal = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
517 Value lhsImag = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
518 Value rhsReal = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
519 Value rhsImag = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
520 Value lhsRealTimesRhsReal =
521 b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
522 Value lhsImagTimesRhsImag =
523 b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
524 Value real = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
525 lhsImagTimesRhsImag, fmfValue);
526 Value lhsImagTimesRhsReal =
527 b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
528 Value lhsRealTimesRhsImag =
529 b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
530 Value imag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
531 lhsRealTimesRhsImag, fmfValue);
541 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
543 auto loc = op.getLoc();
544 auto type = cast<ComplexType>(adaptor.getComplex().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
548 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
550 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
551 Value negReal = rewriter.
create<arith::NegFOp>(loc, real);
552 Value negImag = rewriter.
create<arith::NegFOp>(loc, imag);
558 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
559 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
564 arith::FastMathFlagsAttr fmf)
const override {
575 rewriter.
create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
576 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, sin, fmf);
578 rewriter.
create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
579 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, cos, fmf);
580 return {resultReal, resultImag};
589 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
593 auto type = cast<ComplexType>(op.getType());
594 auto elementType = cast<FloatType>(type.getElementType());
595 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
597 auto cst = [&](APFloat v) {
598 return b.
create<arith::ConstantOp>(elementType,
601 const auto &floatSemantics = elementType.getFloatSemantics();
603 Value half = b.
create<arith::ConstantOp>(elementType,
606 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
607 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
608 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
609 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
610 Value sqrtArg = b.
create<arith::MulFOp>(argArg, half, fmf);
616 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
618 Value resultReal = b.
create<arith::MulFOp>(absSqrt, cos, fmf);
620 sinIsZero, zero, b.
create<arith::MulFOp>(absSqrt, sin, fmf));
621 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
622 arith::FastMathFlags::ninf)) {
623 Value inf = cst(APFloat::getInf(floatSemantics));
624 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
625 Value nan = cst(APFloat::getNaN(floatSemantics));
626 Value absImag = b.
create<math::AbsFOp>(elementType, imag, fmf);
629 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
630 Value absImagIsNotInf =
631 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
633 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
635 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
637 resultReal = b.
create<arith::SelectOp>(
638 b.
create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
640 resultReal = b.
create<arith::SelectOp>(
641 b.
create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
643 Value imagSignInf = b.
create<math::CopySignOp>(inf, imag, fmf);
644 resultImag = b.
create<arith::SelectOp>(
645 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
647 resultImag = b.
create<arith::SelectOp>(
648 b.
create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
653 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
654 resultReal = b.
create<arith::SelectOp>(resultIsZero, zero, resultReal);
655 resultImag = b.
create<arith::SelectOp>(resultIsZero, zero, resultImag);
667 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
669 auto type = cast<ComplexType>(adaptor.getComplex().getType());
670 auto elementType = cast<FloatType>(type.getElementType());
672 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
674 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
675 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
679 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
681 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
682 Value isZero = b.
create<arith::AndIOp>(realIsZero, imagIsZero);
683 auto abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
686 Value sign = b.
create<complex::CreateOp>(type, realSign, imagSign);
688 adaptor.getComplex(), sign);
693 template <
typename Op>
698 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
702 auto type = cast<ComplexType>(adaptor.getComplex().getType());
703 auto elementType = cast<FloatType>(type.getElementType());
704 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
705 const auto &floatSemantics = elementType.getFloatSemantics();
708 b.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
710 b.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
714 if constexpr (std::is_same_v<Op, complex::TanOp>) {
716 std::swap(real, imag);
717 real = b.
create<arith::MulFOp>(real, negOne, fmf);
720 auto cst = [&](APFloat v) {
721 return b.
create<arith::ConstantOp>(elementType,
724 Value inf = cst(APFloat::getInf(floatSemantics));
725 Value four = b.
create<arith::ConstantOp>(elementType,
727 Value twoReal = b.
create<arith::AddFOp>(real, real, fmf);
728 Value negTwoReal = b.
create<arith::MulFOp>(negOne, twoReal, fmf);
730 Value expTwoRealMinusOne = b.
create<math::ExpM1Op>(twoReal, fmf);
731 Value expNegTwoRealMinusOne = b.
create<math::ExpM1Op>(negTwoReal, fmf);
733 b.
create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
736 Value cosImagSq = b.
create<arith::MulFOp>(cosImag, cosImag, fmf);
737 Value twoCosTwoImagPlusOne = b.
create<arith::MulFOp>(cosImagSq, four, fmf);
741 four, b.
create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
743 Value expSumMinusTwo =
744 b.
create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
746 b.
create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
748 Value isInf = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
749 expSumMinusTwo, inf, fmf);
750 Value realLimit = b.
create<math::CopySignOp>(negOne, real, fmf);
753 isInf, realLimit, b.
create<arith::DivFOp>(realNum, denom, fmf));
754 Value resultImag = b.
create<arith::DivFOp>(imagNum, denom, fmf);
756 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
757 arith::FastMathFlags::ninf)) {
761 Value nan = cst(APFloat::getNaN(floatSemantics));
764 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
766 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
768 absRealIsInf, b.
create<arith::ConstantIntOp>(
true, 1));
770 Value imagNumIsNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
771 imagNum, imagNum, fmf);
772 Value resultRealIsNaN =
773 b.
create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
775 imagIsZero, b.
create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
777 resultReal = b.
create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
779 b.
create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
782 if constexpr (std::is_same_v<Op, complex::TanOp>) {
784 std::swap(resultReal, resultImag);
785 resultImag = b.
create<arith::MulFOp>(resultImag, negOne, fmf);
798 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
800 auto loc = op.getLoc();
801 auto type = cast<ComplexType>(adaptor.getComplex().getType());
802 auto elementType = cast<FloatType>(type.getElementType());
804 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
806 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
807 Value negImag = rewriter.
create<arith::NegFOp>(loc, elementType, imag);
820 arith::FastMathFlags fmf) {
821 auto elementType = cast<FloatType>(type.getElementType());
829 Value negD = builder.
create<arith::NegFOp>(d, fmf);
830 Value argLhs = builder.
create<math::Atan2Op>(b, a, fmf);
831 Value negDArgLhs = builder.
create<arith::MulFOp>(negD, argLhs, fmf);
832 Value expNegDArgLhs = builder.
create<math::ExpOp>(negDArgLhs, fmf);
834 Value coeff = builder.
create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
836 Value cArgLhs = builder.
create<arith::MulFOp>(c, argLhs, fmf);
837 Value dLnAbs = builder.
create<arith::MulFOp>(d, lnAbs, fmf);
838 Value q = builder.
create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
845 APFloat::getInf(elementType.getFloatSemantics())));
850 Value complexOne = builder.
create<complex::CreateOp>(type, one, zero);
851 Value complexZero = builder.
create<complex::CreateOp>(type, zero, zero);
852 Value complexInf = builder.
create<complex::CreateOp>(type, inf, zero);
859 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
abs, zero, fmf);
861 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
863 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
865 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
868 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
869 Value coeffCosQ = builder.
create<arith::MulFOp>(coeff, cosQ, fmf);
870 Value coeffSinQ = builder.
create<arith::MulFOp>(coeff, sinQ, fmf);
871 Value complexOneOrZero =
872 builder.
create<arith::SelectOp>(cEqZero, complexOne, complexZero);
874 builder.
create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
876 builder.
create<arith::AndIOp>(
877 builder.
create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
878 complexOneOrZero, coeffCosSin);
884 Value rhsEqZero = builder.
create<arith::AndIOp>(cEqZero, dEqZero);
886 builder.
create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
891 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
894 builder.
create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
899 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
903 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
905 builder.
create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
909 Value rhsLt0 = builder.create<arith::AndIOp>(
911 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
912 Value cutoff4 = builder.create<arith::SelectOp>(
913 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
922 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
925 auto type = cast<ComplexType>(adaptor.getLhs().getType());
926 auto elementType = cast<FloatType>(type.getElementType());
928 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
929 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
931 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
932 c, d, op.getFastmath())});
941 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
944 auto type = cast<ComplexType>(adaptor.getComplex().getType());
945 auto elementType = cast<FloatType>(type.getElementType());
947 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
949 auto cst = [&](APFloat v) {
950 return b.
create<arith::ConstantOp>(elementType,
953 const auto &floatSemantics = elementType.getFloatSemantics();
955 Value inf = cst(APFloat::getInf(floatSemantics));
958 Value nan = cst(APFloat::getNaN(floatSemantics));
960 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
961 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
962 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
963 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
964 Value rsqrtArg = b.
create<arith::MulFOp>(argArg, negHalf, fmf);
968 Value resultReal = b.
create<arith::MulFOp>(absRsqrt, cos, fmf);
969 Value resultImag = b.
create<arith::MulFOp>(absRsqrt, sin, fmf);
971 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
972 arith::FastMathFlags::ninf)) {
976 Value realSignedZero = b.
create<math::CopySignOp>(zero, real, fmf);
977 Value imagSignedZero = b.
create<math::CopySignOp>(zero, imag, fmf);
978 Value negImagSignedZero =
979 b.
create<arith::MulFOp>(negOne, imagSignedZero, fmf);
985 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
987 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
989 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
990 Value inIsNanInf = b.
create<arith::AndIOp>(absImagIsInf, realIsNan);
992 Value resultIsZero = b.
create<arith::OrIOp>(inIsNanInf, realIsInf);
995 b.
create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
996 resultImag = b.
create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1001 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1003 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1004 Value isZero = b.
create<arith::AndIOp>(isRealZero, isImagZero);
1006 resultReal = b.
create<arith::SelectOp>(isZero, inf, resultReal);
1007 resultImag = b.
create<arith::SelectOp>(isZero, nan, resultImag);
1019 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1021 auto loc = op.getLoc();
1022 auto type = op.getType();
1023 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1026 rewriter.
create<complex::ReOp>(loc, type, adaptor.getComplex());
1028 rewriter.
create<complex::ImOp>(loc, type, adaptor.getComplex());
1045 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1046 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1047 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1048 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1060 TanTanhOpConversion<complex::TanOp>,
1061 TanTanhOpConversion<complex::TanhOp>,
1072 struct ConvertComplexToStandardPass
1073 :
public impl::ConvertComplexToStandardPassBase<
1074 ConvertComplexToStandardPass> {
1075 using ConvertComplexToStandardPassBase::ConvertComplexToStandardPassBase;
1077 void runOnOperation()
override;
1080 void ConvertComplexToStandardPass::runOnOperation() {
1086 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1087 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1090 signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the arith/math dialects using the algebraic method.
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the arith/math dialects using Smith's method.
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.