17 #include <type_traits>
20 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21 #include "mlir/Conversion/Passes.h.inc"
// Selects what the magnitude helper (computeAbs) produces: the plain
// magnitude (abs), or the variants built by that helper's
// `fn == AbsFn::sqrt` / `fn == AbsFn::rsqrt` branches. The SqrtOp and RsqrtOp
// lowerings pass AbsFn::sqrt and AbsFn::rsqrt respectively.
28 enum class AbsFn {
abs, sqrt, rsqrt };
// Body of the magnitude helper (its signature is not part of this excerpt; it
// is called elsewhere as computeAbs(real, imag, fmf, b[, fn])). Works on
// |real| and |imag| and factors out the larger of the two:
//   max * sqrt(1 + (min / max)^2).
36 Value absReal = math::AbsFOp::create(b, real, fmf);
37 Value absImag = math::AbsFOp::create(b, imag, fmf);
39 Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
// Everything below is built with the nnan and ninf bits cleared from the
// caller's fast-math flags.
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(b,
min,
max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
// rsqrt variant: replace each factor (and min) by its reciprocal square root.
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(b,
min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(b,
max, fmfWithNaNInf);
// sqrt variant: result = sqrt(max) * p025. The visible PowFOp raises
// ratioSqPlusOne to `quarter`; presumably that value is what p025 names.
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
60 Value sqrt = math::SqrtOp::create(b,
max, fmfWithNaNInf);
62 math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
// Remaining path: result = max * sqrt(ratioSqPlusOne).
65 Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(b,
max, sqrt, fmfWithNaNInf);
// A NaN result (detected by an unordered self-compare) is replaced by min.
69 Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70 result, fmfWithNaNInf);
71 return arith::SelectOp::create(b, isNaN,
min, result);
// Lowers complex.abs: extracts the real and imaginary parts of the operand
// and replaces the op with the value computeAbs builds from them, passing
// along the op's fast-math flags.
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 Value real = complex::ReOp::create(b, adaptor.getComplex());
85 Value imag = complex::ImOp::create(b, adaptor.getComplex());
86 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
// Lowers complex.atan2 into other complex-dialect ops. The visible part
// builds log((rhs + i * lhs) / sqrt(rhs^2 + lhs^2)) and the constant -i
// (negativeI); the op that combines the two is outside this excerpt
// (presumably their product).
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
// sqrt(rhs^2 + lhs^2), computed with complex multiplication/addition.
108 Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
116 arith::ConstantOp::create(b, elementType, b.
getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(b, elementType,
119 Value i = complex::CreateOp::create(b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(b, divResult, fmf);
127 Value negativeOne = arith::ConstantOp::create(
129 Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
// Lowering template for complex comparisons (instantiated in the pattern list
// as ComparisonOpConversion<complex::EqualOp, OEQ> and
// ComparisonOpConversion<complex::NotEqualOp, UNE>). The real parts and the
// imaginary parts are compared separately with arith.cmpf using predicate p.
136 template <
typename ComparisonOp, arith::CmpFPredicate p>
// arith::AndIOp for complex::EqualOp, arith::OrIOp for any other ComparisonOp.
// Presumably used to merge realComparison and imagComparison; that use is not
// visible in this excerpt.
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
144 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
// Component-wise lowering template (instantiated in the pattern list for
// complex::AddOp -> arith::AddFOp and complex::SubOp -> arith::SubFOp).
// BinaryStandardOp is applied to the two real parts and, independently, to
// the two imaginary parts, each time with the op's fast-math flags.
171 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
176 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
183 Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
// Shared base for the complex sin/cos lowerings (CosOpConversion and
// SinOpConversion derive from it). matchAndRewrite computes four scalars from
// the operand -- half * exp(imag), half / exp(imag), sin(real), cos(real) --
// and passes them to combine(), which yields the (real, imag) result pair.
197 template <
typename TrigonometricOp>
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
// scaledExp = 0.5 * e^imag and reciprocalExp = 0.5 / e^imag.
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc, real, fmf);
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
// Pure-virtual hook returning a pair of Values; its name and leading
// parameters are on lines missing from this excerpt (presumably this is the
// combine() declaration that the derived structs override).
235 virtual std::pair<Value, Value>
238 arith::FastMathFlagsAttr fmf)
const = 0;
// complex.cos lowering on top of TrigonometricOpConversion. The override
// returns
//   real = sum  * cos   (the visible AddFOp is reciprocalExp + scaledExp)
//   imag = diff * sin   (the visible SubFOp is reciprocalExp - scaledExp)
// The `Value sum =` / `Value diff =` lines are missing from this excerpt, so
// the binding of those two ops to sum/diff is presumed.
241 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
247 arith::FastMathFlagsAttr fmf)
const override {
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
// complex.div lowering. The constructor takes the ComplexRangeFlags to lower
// for, and a complexRange member of that type is declared at the end of this
// excerpt. matchAndRewrite splits both operands into real/imag parts and
// dispatches on complexRange: one helper call for `basic` or `none`, another
// for `improved`. The callee names are on lines missing here.
// NOTE(review): presumably the algebraic and range-reduction division helpers
// respectively -- confirm against the full file.
268 DivOpConversion(
MLIRContext *context, complex::ComplexRangeFlags target)
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
// Output slots: both visible call sites pass &resultReal to the helper.
290 Value resultReal, resultImag;
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
297 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
310 complex::ComplexRangeFlags complexRange;
// Lowers complex.exp for z = x + i*y as exp(x) * (cos(y) + i * sin(y)), with
// an alternative product used when exp(x) compares equal to +inf and an
// exact zero imaginary part when y == 0.
321 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
323 auto loc = op.getLoc();
324 auto type = cast<ComplexType>(adaptor.getComplex().getType());
325 auto ET = cast<FloatType>(type.getElementType());
326 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327 const auto &floatSemantics = ET.getFloatSemantics();
330 Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331 Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
334 Value inf = arith::ConstantOp::create(
335 b, ET, b.
getFloatAttr(ET, APFloat::getInf(floatSemantics)));
// Building blocks: exp(x), exp(x * half), cos(y), sin(y). The definition of
// `half` is on a line missing from this excerpt.
337 Value exp = math::ExpOp::create(b, x, fmf);
338 Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339 Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340 Value cos = math::CosOp::create(b, y, fmf);
341 Value sin = math::SinOp::create(b, y, fmf);
344 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
346 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
// Real part: exp * cos normally, (expHalf * cos) * expHalf when expIsInf.
349 Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350 Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351 Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
353 arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
// Imaginary part: same scheme with sin, then forced to zero when yIsZero.
357 Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358 Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359 Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
361 arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362 Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
// Tail of the signature and body of the polynomial helper (called in this
// file as evaluatePolynomial(b, argPow2, kCoeffs, fmf)). Starts from a
// constant built from coefficients[0] and folds in coefficients[1..] one at a
// time through math.fma. The other fma operands are on lines missing from
// this excerpt (presumably poly and arg, i.e. Horner's scheme with
// coefficients[0] as the leading term -- confirm against the full file).
372 arith::FastMathFlagsAttr fmf) {
373 auto argType = mlir::cast<FloatType>(arg.
getType());
375 arith::ConstantOp::create(b, b.
getFloatAttr(argType, coefficients[0]));
376 for (
unsigned i = 1; i < coefficients.size(); ++i) {
377 poly = math::FmaOp::create(
379 arith::ConstantOp::create(b, b.
getFloatAttr(argType, coefficients[i])),
// Lowers complex.expm1 for z = a + i*b using expm1(a) and cos(b) - 1 (from
// emitCosm1):
//   real = expm1(a) * cos(b) + (cos(b) - 1)
//   imag = exp(a) * sin(b), selected to `zero` when b == 0 (ordered compare)
// where exp(a) = expm1(a) + one and cos(b) = cosm1(b) + one.
393 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
395 auto type = op.getType();
396 auto elemType = mlir::cast<FloatType>(type.getElementType());
398 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
400 Value real = complex::ReOp::create(b, adaptor.getComplex());
401 Value imag = complex::ImOp::create(b, adaptor.getComplex());
406 Value expm1Real = math::ExpM1Op::create(b, real, fmf);
407 Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
409 Value sinImag = math::SinOp::create(b, imag, fmf);
410 Value cosm1Imag = emitCosm1(imag, fmf, b);
411 Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
413 Value realResult = arith::AddFOp::create(
414 b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
416 Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
417 zero, fmf.getValue());
418 Value imagResult = arith::SelectOp::create(
419 b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
// Emits cos(arg) - 1 with two evaluation strategies, chosen per value by a
// select:
//   - arg^2 >= 0.61685 (piOver4Pow2): cos(arg) + (-1)
//   - otherwise: arg^4 * poly(arg^2) + (-0.5) * arg^2, where poly is
//     evaluatePolynomial over the coefficient table listed below.
427 Value emitCosm1(
Value arg, arith::FastMathFlagsAttr fmf,
429 auto argType = mlir::cast<FloatType>(arg.
getType());
430 auto negHalf = arith::ConstantOp::create(b, b.
getFloatAttr(argType, -0.5));
431 auto negOne = arith::ConstantOp::create(b, b.
getFloatAttr(argType, -1.0));
435 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438 4.1666666666666666609054E-2,
440 Value cos = math::CosOp::create(b, arg, fmf);
441 Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
443 Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
444 Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
445 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
448 arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
449 arith::MulFOp::create(b, negHalf, argPow2, fmf));
453 arith::ConstantOp::create(b, b.
getFloatAttr(argType, 0.61685));
454 Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
455 piOver4Pow2, fmf.getValue());
456 return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
// Lowers complex.log: the real part is math.log of complex.abs of the
// operand; the visible math.atan2(imag, real) is presumably the imaginary
// part (its assignment line is missing from this excerpt).
464 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
466 auto type = cast<ComplexType>(adaptor.getComplex().getType());
467 auto elementType = cast<FloatType>(type.getElementType());
468 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
471 Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
473 Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
474 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
475 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
477 math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
// Lowers complex.log1p. With maxAbs/minAbs the larger/smaller of |real + 1|
// and |imag|, the real part is built as
//   r = half * log1p((minAbs / maxAbs)^2 ...) + log1p(m)
// where m is `real` when real + 1 > |imag| and maxAbs - 1 otherwise. A final
// select keyed on an unordered self-compare of r picks the real result (its
// remaining operands are on lines missing here). The imaginary part is
// atan2(imag, real + 1).
488 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
490 auto type = cast<ComplexType>(adaptor.getComplex().getType());
491 auto elementType = cast<FloatType>(type.getElementType());
492 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
495 Value real = complex::ReOp::create(b, adaptor.getComplex());
496 Value imag = complex::ImOp::create(b, adaptor.getComplex());
498 Value half = arith::ConstantOp::create(b, elementType,
500 Value one = arith::ConstantOp::create(b, elementType,
502 Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
503 Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
504 Value absImag = math::AbsFOp::create(b, imag, fmf);
// Order the two magnitudes so the ratio below is min / max.
506 Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
507 Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
509 Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
510 realPlusOne, absImag, fmf);
511 Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
512 Value maxAbsOfRealPlusOneAndImagMinusOne =
513 arith::SelectOp::create(b, useReal, real, maxMinusOne);
// The ratio arithmetic runs with nnan and ninf cleared from the flags.
514 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516 Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
517 Value logOfMaxAbsOfRealPlusOneAndImag =
518 math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519 Value logOfSqrtPart = math::Log1pOp::create(
520 b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
522 Value r = arith::AddFOp::create(
523 b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
524 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525 Value resultReal = arith::SelectOp::create(
527 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
530 Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
// Lowers complex.mul with the textbook formula, using the op's fast-math
// flags on every arith op:
//   real = lhsReal * rhsReal - lhsImag * rhsImag
//   imag = lhsImag * rhsReal + lhsReal * rhsImag
541 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
544 auto type = cast<ComplexType>(adaptor.getLhs().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547 auto fmfValue = fmf.getValue();
548 Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
549 Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
550 Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
551 Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
552 Value lhsRealTimesRhsReal =
553 arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
554 Value lhsImagTimesRhsImag =
555 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
556 Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
557 lhsImagTimesRhsImag, fmfValue);
558 Value lhsImagTimesRhsReal =
559 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
560 Value lhsRealTimesRhsImag =
561 arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
562 Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
563 lhsRealTimesRhsImag, fmfValue);
// Lowers complex.neg by applying arith.negf to the real part and to the
// imaginary part separately. The op that reassembles the two is outside this
// excerpt.
573 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
575 auto loc = op.getLoc();
576 auto type = cast<ComplexType>(adaptor.getComplex().getType());
577 auto elementType = cast<FloatType>(type.getElementType());
580 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
582 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
583 Value negReal = arith::NegFOp::create(rewriter, loc, real);
584 Value negImag = arith::NegFOp::create(rewriter, loc, imag);
// complex.sin lowering on top of TrigonometricOpConversion. The override
// returns
//   real = sum  * sin   (the visible AddFOp is scaledExp + reciprocalExp)
//   imag = diff * cos   (the visible SubFOp is scaledExp - reciprocalExp)
// The `Value sum =` / `Value diff =` lines are missing from this excerpt, so
// the binding of those two ops to sum/diff is presumed.
590 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
591 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
596 arith::FastMathFlagsAttr fmf)
const override {
607 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
608 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
610 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
611 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
612 return {resultReal, resultImag};
// Lowers complex.sqrt in polar form: the modulus comes from
// computeAbs(..., AbsFn::sqrt), the angle is atan2(imag, real) * half, and the
// result components are modulus * cos(angle) and modulus * sin(angle).
// Additional selects patch up infinite/NaN inputs (unless both nnan and ninf
// are set) and a zero modulus.
621 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
625 auto type = cast<ComplexType>(op.getType());
626 auto elementType = cast<FloatType>(type.getElementType());
627 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
629 auto cst = [&](APFloat v) {
630 return arith::ConstantOp::create(b, elementType,
633 const auto &floatSemantics = elementType.getFloatSemantics();
635 Value half = arith::ConstantOp::create(b, elementType,
638 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
639 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
// Polar decomposition of the result.
640 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
641 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
642 Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
643 Value cos = math::CosOp::create(b, sqrtArg, fmf);
644 Value sin = math::SinOp::create(b, sqrtArg, fmf);
648 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
// The imaginary part is selected to `zero` when sin compares equal to zero.
650 Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
651 Value resultImag = arith::SelectOp::create(
652 b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
// Inf/NaN fix-ups are emitted unless the flags contain both nnan and ninf.
653 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
654 arith::FastMathFlags::ninf)) {
655 Value inf = cst(APFloat::getInf(floatSemantics));
656 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
657 Value nan = cst(APFloat::getNaN(floatSemantics));
658 Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
660 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
662 Value absImagIsNotInf = arith::CmpFOp::create(
663 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
665 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
666 Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
669 resultReal = arith::SelectOp::create(
670 b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
672 resultReal = arith::SelectOp::create(
673 b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
675 Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
676 resultImag = arith::SelectOp::create(
678 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
680 resultImag = arith::SelectOp::create(
681 b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
686 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
// A zero modulus forces both result components to zero.
687 resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
688 resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
// Lowers complex.sign: builds sign = (real / |z|, imag / |z|) using
// complex.abs, and an isZero predicate that is true when both components
// compare equal to zero. The final line passes the original operand and
// `sign` to an op whose head is missing from this excerpt (presumably a
// select on isZero that keeps the input for zero).
700 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
702 auto type = cast<ComplexType>(adaptor.getComplex().getType());
703 auto elementType = cast<FloatType>(type.getElementType());
705 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
707 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
708 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
710 arith::ConstantOp::create(b, elementType, b.
getZeroAttr(elementType));
712 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
714 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
715 Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
717 complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
718 Value realSign = arith::DivFOp::create(b, real, abs, fmf);
719 Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
720 Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
722 adaptor.getComplex(), sign);
// Shared lowering for complex::TanOp and complex::TanhOp (both appear in the
// pattern list as TanTanhOpConversion<...>). The body computes the tanh form
// from expm1(2 * real), expm1(-2 * real), cos(imag) and sin(imag). For TanOp
// only, the input components are swapped and the new real part is multiplied
// by negOne before that computation, and the result components are swapped
// and the new imaginary part is multiplied by negOne afterwards.
727 template <
typename Op>
732 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
736 auto type = cast<ComplexType>(adaptor.getComplex().getType());
737 auto elementType = cast<FloatType>(type.getElementType());
738 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
739 const auto &floatSemantics = elementType.getFloatSemantics();
742 complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
744 complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
745 Value negOne = arith::ConstantOp::create(b, elementType,
748 if constexpr (std::is_same_v<Op, complex::TanOp>) {
750 std::swap(real, imag);
751 real = arith::MulFOp::create(b, real, negOne, fmf);
754 auto cst = [&](APFloat v) {
755 return arith::ConstantOp::create(b, elementType,
758 Value inf = cst(APFloat::getInf(floatSemantics));
759 Value four = arith::ConstantOp::create(b, elementType,
761 Value twoReal = arith::AddFOp::create(b, real, real, fmf);
762 Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
// Real numerator: expm1(2 * real) - expm1(-2 * real).
764 Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
765 Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
766 Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
767 expNegTwoRealMinusOne, fmf);
769 Value cosImag = math::CosOp::create(b, imag, fmf);
770 Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
771 Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
772 Value sinImag = math::SinOp::create(b, imag, fmf);
// Imaginary numerator: four * cos(imag) * sin(imag).
774 Value imagNum = arith::MulFOp::create(
775 b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
777 Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
778 expNegTwoRealMinusOne, fmf);
780 arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
// When expSumMinusTwo compares equal to inf, the real result is realLimit
// (negOne carrying the sign of real) instead of realNum / denom.
782 Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
783 expSumMinusTwo, inf, fmf);
784 Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
786 Value resultReal = arith::SelectOp::create(
787 b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
788 Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
// NaN/zero fix-ups are emitted unless the flags contain both nnan and ninf.
790 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
791 arith::FastMathFlags::ninf)) {
792 Value absReal = math::AbsFOp::create(b, real, fmf);
793 Value zero = arith::ConstantOp::create(b, elementType,
795 Value nan = cst(APFloat::getNaN(floatSemantics));
797 Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
800 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
801 Value absRealIsNotInf = arith::XOrIOp::create(
804 Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
805 imagNum, imagNum, fmf);
806 Value resultRealIsNaN =
807 arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
808 Value resultImagIsZero = arith::OrIOp::create(
809 b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
811 resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
813 arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
// TanOp only: swap the result components back and negate the imaginary one.
816 if constexpr (std::is_same_v<Op, complex::TanOp>) {
818 std::swap(resultReal, resultImag);
819 resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
// Lowers complex.conj: extracts both components and negates the imaginary
// one with arith.negf. The op that recombines them is outside this excerpt.
832 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
834 auto loc = op.getLoc();
835 auto type = cast<ComplexType>(adaptor.getComplex().getType());
836 auto elementType = cast<FloatType>(type.getElementType());
838 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
840 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
841 Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
// Tail of the signature and body of the pow helper (called by the PowOp
// pattern as powOpConversionImpl(builder, type, lhs, c, d, fmf)). With
// a + i*b = lhs and exponent c + i*d, the generic result is
//   coeff = |lhs|^c * exp(-d * atan2(b, a))
//   q     = c * atan2(b, a) + d * ln|lhs|
//   (coeff * cos(q), coeff * sin(q))
// and a chain of selects (cutoff0, cutoff1, ...) overrides it for the special
// cases annotated below.
854 arith::FastMathFlags fmf) {
855 auto elementType = cast<FloatType>(type.getElementType());
857 Value a = complex::ReOp::create(builder, lhs);
858 Value b = complex::ImOp::create(builder, lhs);
860 Value abs = complex::AbsOp::create(builder, lhs, fmf);
861 Value absToC = math::PowFOp::create(builder, abs, c, fmf);
863 Value negD = arith::NegFOp::create(builder, d, fmf);
864 Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
865 Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
866 Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
868 Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
869 Value lnAbs = math::LogOp::create(builder, abs, fmf);
870 Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
871 Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
872 Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
873 Value cosQ = math::CosOp::create(builder, q, fmf);
874 Value sinQ = math::SinOp::create(builder, q, fmf);
// Constants used by the special-case selects.
876 Value inf = arith::ConstantOp::create(
877 builder, elementType,
879 APFloat::getInf(elementType.getFloatSemantics())));
880 Value zero = arith::ConstantOp::create(
881 builder, elementType, builder.
getFloatAttr(elementType, 0.0));
882 Value one = arith::ConstantOp::create(builder, elementType,
884 Value complexOne = complex::CreateOp::create(builder, type, one, zero);
885 Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
886 Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
893 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
895 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
897 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
899 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
902 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
// Generic polar-form components, then cutoff0: when |lhs| == 0, d == 0 and
// 0 <= c, pick complex one (c == 0) or complex zero instead.
903 Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
904 Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
905 Value complexOneOrZero =
906 arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
908 complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
909 Value cutoff0 = arith::SelectOp::create(
911 arith::AndIOp::create(
912 builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
913 complexOneOrZero, coeffCosSin);
// Exponent equal to zero (c == 0 and d == 0) yields complex one.
919 Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
921 arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
// lhsEqOne (a == one combined with a condition on a line missing from this
// excerpt) also yields complex one.
925 Value lhsEqOne = arith::AndIOp::create(
927 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
930 arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
// lhsEqInf with rhsGt0 selects complexInf; lhsEqInf with rhsLt0 selects
// complexZero. The first operand of each AndIOp is on a missing line.
934 Value lhsEqInf = arith::AndIOp::create(
936 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
938 Value rhsGt0 = arith::AndIOp::create(
940 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
941 Value cutoff3 = arith::SelectOp::create(
942 builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
947 Value rhsLt0 = arith::AndIOp::create(
949 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
950 Value cutoff4 = arith::SelectOp::create(
951 builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
// Lowers complex.powi by converting the exponent to the element float type
// with arith.sitofp, wrapping it as the complex value (exponent, 0.0), and
// creating a complex.pow with the op's fastmath attribute.
961 matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
964 auto type = cast<ComplexType>(op.getType());
965 auto elementType = cast<FloatType>(type.getElementType());
967 Value floatExponent =
968 arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
969 Value zero = arith::ConstantOp::create(
970 builder, elementType, builder.
getFloatAttr(elementType, 0.0));
971 Value complexExponent =
972 complex::CreateOp::create(builder, type, floatExponent, zero);
974 auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
975 complexExponent, op.getFastmathAttr());
// Lowers complex.pow: splits the exponent (rhs) into its real part c and
// imaginary part d and replaces the op with the value produced by
// powOpConversionImpl for the base (lhs), c, d and the op's fast-math flags.
985 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
988 auto type = cast<ComplexType>(adaptor.getLhs().getType());
989 auto elementType = cast<FloatType>(type.getElementType());
991 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
992 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
994 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
995 c, d, op.getFastmath())});
// Lowers complex.rsqrt in polar form: the modulus comes from
// computeAbs(..., AbsFn::rsqrt), the angle is atan2(imag, real) * (-0.5), and
// the result components are modulus * cos(angle) and modulus * sin(angle).
// Unless both nnan and ninf are set, extra selects handle infinite/NaN and
// zero inputs.
1004 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1007 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008 auto elementType = cast<FloatType>(type.getElementType());
1010 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1012 auto cst = [&](APFloat v) {
1013 return arith::ConstantOp::create(b, elementType,
1014 b.getFloatAttr(elementType, v));
1016 const auto &floatSemantics = elementType.getFloatSemantics();
1018 Value inf = cst(APFloat::getInf(floatSemantics));
1019 Value negHalf = arith::ConstantOp::create(
1020 b, elementType, b.getFloatAttr(elementType, -0.5));
1021 Value nan = cst(APFloat::getNaN(floatSemantics));
1023 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
1024 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
// Polar decomposition of the result.
1025 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1026 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
1027 Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
1028 Value cos = math::CosOp::create(b, rsqrtArg, fmf);
1029 Value sin = math::SinOp::create(b, rsqrtArg, fmf);
1031 Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
1032 Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
// Special cases are emitted unless the flags contain both nnan and ninf.
1034 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1035 arith::FastMathFlags::ninf)) {
1036 Value negOne = arith::ConstantOp::create(b, elementType,
1037 b.getFloatAttr(elementType, -1));
// Signed zeros: zero carrying the sign of real, and the negation of zero
// carrying the sign of imag.
1039 Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
1040 Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
1041 Value negImagSignedZero =
1042 arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
1044 Value absReal = math::AbsFOp::create(b, real, fmf);
1045 Value absImag = math::AbsFOp::create(b, imag, fmf);
1047 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1050 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
1051 Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1053 Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
1055 Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
1058 arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
1059 resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
1064 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
1066 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
1067 Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
// Zero input: the real part becomes inf and the imaginary part becomes nan.
1069 resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
1070 resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
// Lowers complex.angle: reads the real and imaginary parts of the operand,
// typed with the op's result type. The op that combines them is on lines
// missing from this excerpt.
1082 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1084 auto loc = op.getLoc();
1085 auto type = op.getType();
1086 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1089 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1091 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
// Excerpt of the registered pattern list: add/sub lower component-wise to
// arith.addf/arith.subf, equality uses the ordered-equal predicate (OEQ),
// inequality uses the unordered-or-not-equal predicate (UNE), and tan/tanh
// share TanTanhOpConversion. The other entries are on lines missing here.
1108 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1109 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1110 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1111 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1123 TanTanhOpConversion<complex::TanOp>,
1124 TanTanhOpConversion<complex::TanhOp>,
// Pass wrapper: derives from the generated
// impl::ConvertComplexToStandardPassBase (enabled by the
// GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS define at the top of the file) and
// overrides runOnOperation, which is defined out of line.
1136 struct ConvertComplexToStandardPass
1137 :
public impl::ConvertComplexToStandardPassBase<
1138 ConvertComplexToStandardPass> {
1141 void runOnOperation()
override;
// Pass entry point. The conversion target treats the arith and math dialects
// as legal, plus complex.create / complex.im / complex.re; the pass signals
// failure on the error path (the condition guarding signalPassFailure is on
// lines missing from this excerpt, presumably a failed partial conversion).
1144 void ConvertComplexToStandardPass::runOnOperation() {
1150 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1151 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1154 signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the arith/math dialects using the algebraic method.
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the arith/math dialects using Smith's method.
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.