17 #include <type_traits>
20 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21 #include "mlir/Conversion/Passes.h.inc"
// Selects which function of the complex modulus |z| computeAbs emits:
//   abs   -> |z|
//   sqrt  -> sqrt(|z|)      (used by the complex.sqrt lowering)
//   rsqrt -> 1 / sqrt(|z|)  (used by the complex.rsqrt lowering)
28 enum class AbsFn {
abs, sqrt, rsqrt };
// Body of computeAbs (its signature is outside this view). For z = real +
// i*imag it computes |z| as max * sqrt(1 + (min/max)^2), where max/min are
// the larger/smaller of |real| and |imag|. The ratio min/max is at most 1,
// so neither component is squared directly.
36 Value absReal = math::AbsFOp::create(b, real, fmf);
37 Value absImag = math::AbsFOp::create(b, imag, fmf);
39 Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
// nnan/ninf are cleared for the intermediate ops. min/max can legitimately
// be NaN (0/0 when both parts are zero, inf/inf when both are infinite), and
// the NaN check at the end must not be assumed away by fast-math flags.
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(b,
min,
max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
// rsqrt variant: 1/sqrt(|z|) = rsqrt(max) * (1 + r^2)^(-1/4). Here max, min
// and (1 + r^2) are replaced by their rsqrt. The remaining square root of
// rsqrt(1 + r^2) presumably comes from the non-sqrt branch below (the
// branch's `else` line is not visible here).
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(b,
min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(b,
max, fmfWithNaNInf);
// sqrt variant: sqrt(|z|) = sqrt(max) * (1 + r^2)^(1/4).
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
60 Value sqrt = math::SqrtOp::create(b,
max, fmfWithNaNInf);
62 math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
// Default path: |z| = max * sqrt(1 + r^2). For rsqrt, max and (1 + r^2)
// already hold their rsqrt values at this point.
65 Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(b,
max, sqrt, fmfWithNaNInf);
// If the result is NaN, fall back to `min`. When both parts are zero or both
// are infinite, min already holds the correct answer: 0 or inf for abs/sqrt,
// and rsqrt(min) for rsqrt. If an input is NaN, min is NaN as well.
69 Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70 result, fmfWithNaNInf);
71 return arith::SelectOp::create(b, isNaN,
min, result);
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 Value real = complex::ReOp::create(b, adaptor.getComplex());
85 Value imag = complex::ImOp::create(b, adaptor.getComplex());
86 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
// Lowers complex.atan2(lhs, rhs) using the identity
//   atan2(y, x) = -i * log((x + i*y) / sqrt(x^2 + y^2)),  with y = lhs, x = rhs.
// Every op emitted here is itself a complex op (mul/add/sqrt/div/log). They
// are presumably legalized afterwards by the other patterns in this file.
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
// sqrt(rhs^2 + lhs^2), computed in complex arithmetic.
108 Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
116 arith::ConstantOp::create(b, elementType, b.
getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(b, elementType,
// Numerator rhs + i*lhs.
119 Value i = complex::CreateOp::create(b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(b, divResult, fmf);
// -i constant. The final multiplication of logResult by -i is not visible
// in this fragment (TODO confirm).
127 Value negativeOne = arith::ConstantOp::create(
129 Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
// Lowers a complex comparison by comparing the real parts and the imaginary
// parts separately with the float predicate `p`.
// Per the pattern list, this is instantiated as complex.eq with OEQ and
// complex.neq with UNE.
136 template <
typename ComparisonOp, arith::CmpFPredicate p>
// Equality needs both parts to match, so the results are AND-ed. Any other
// op (i.e. inequality) holds if either part differs, so they are OR-ed.
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
144 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
146 auto loc = op.getLoc();
// `type` is the element (float) type of the complex operands.
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
171 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
176 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
183 Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
// Shared lowering for complex.sin / complex.cos of z = x + i*y.
// It precomputes:
//   scaledExp     = e^y / 2
//   reciprocalExp = 1 / (2 * e^y) = e^-y / 2
//   sin(x), cos(x)
// Subclasses then build the result in `combine`, using
//   cosh(y) = scaledExp + reciprocalExp
//   sinh(y) = scaledExp - reciprocalExp.
197 template <
typename TrigonometricOp>
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc, real, fmf);
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
// Hook implemented by each concrete op. It returns the (real, imag) parts of
// the result, built from the precomputed terms above.
235 virtual std::pair<Value, Value>
238 arith::FastMathFlagsAttr fmf)
const = 0;
// complex.cos: cos(x + i*y) = cos(x)*cosh(y) - i*sin(x)*sinh(y).
241 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
247 arith::FastMathFlagsAttr fmf)
const override {
// sum = cosh(y)
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
// diff = reciprocalExp - scaledExp = -sinh(y), which supplies the minus sign.
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
268 DivOpConversion(
MLIRContext *context, complex::ComplexRangeFlags target)
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
290 Value resultReal, resultImag;
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
297 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
310 complex::ComplexRangeFlags complexRange;
317 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
319 auto loc = op.getLoc();
320 auto type = cast<ComplexType>(adaptor.getComplex().getType());
321 auto elementType = cast<FloatType>(type.getElementType());
322 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
325 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
327 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
328 Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
329 Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
331 arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
332 Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
334 arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
344 arith::FastMathFlagsAttr fmf) {
345 auto argType = mlir::cast<FloatType>(arg.
getType());
347 arith::ConstantOp::create(b, b.
getFloatAttr(argType, coefficients[0]));
348 for (
unsigned i = 1; i < coefficients.size(); ++i) {
349 poly = math::FmaOp::create(
351 arith::ConstantOp::create(b, b.
getFloatAttr(argType, coefficients[i])),
365 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
367 auto type = op.getType();
368 auto elemType = mlir::cast<FloatType>(type.getElementType());
370 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
372 Value real = complex::ReOp::create(b, adaptor.getComplex());
373 Value imag = complex::ImOp::create(b, adaptor.getComplex());
378 Value expm1Real = math::ExpM1Op::create(b, real, fmf);
379 Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
381 Value sinImag = math::SinOp::create(b, imag, fmf);
382 Value cosm1Imag = emitCosm1(imag, fmf, b);
383 Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
385 Value realResult = arith::AddFOp::create(
386 b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
388 Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
389 zero, fmf.getValue());
390 Value imagResult = arith::SelectOp::create(
391 b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
// Emits cos(arg) - 1, staying accurate for small |arg|.
// - For arg^2 >= (pi/4)^2 it uses cos(arg) - 1 directly.
// - Below that cutoff it uses the series -arg^2/2 + arg^4 * P(arg^2).
// The series avoids the cancellation that computing cos(arg) - 1 directly
// suffers near 0.
399 Value emitCosm1(
Value arg, arith::FastMathFlagsAttr fmf,
401 auto argType = mlir::cast<FloatType>(arg.
getType());
402 auto negHalf = arith::ConstantOp::create(b, b.
getFloatAttr(argType, -0.5));
403 auto negOne = arith::ConstantOp::create(b, b.
getFloatAttr(argType, -1.0));
// Coefficients of P in powers of arg^2, listed highest degree first. The last
// two entries are ~-1/720 and ~1/24, matching the Taylor series of cos.
407 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
408 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
409 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
410 4.1666666666666666609054E-2,
412 Value cos = math::CosOp::create(b, arg, fmf);
413 Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
415 Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
416 Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
417 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
// forSmallArg = arg^4 * P(arg^2) - arg^2 / 2
420 arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
421 arith::MulFOp::create(b, negHalf, argPow2, fmf));
// 0.61685 ~= (pi/4)^2, the cutoff between the two evaluation strategies.
425 arith::ConstantOp::create(b, b.
getFloatAttr(argType, 0.61685));
426 Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
427 piOver4Pow2, fmf.getValue());
428 return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
436 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
438 auto type = cast<ComplexType>(adaptor.getComplex().getType());
439 auto elementType = cast<FloatType>(type.getElementType());
440 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
443 Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
445 Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
446 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
447 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
449 math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
460 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
462 auto type = cast<ComplexType>(adaptor.getComplex().getType());
463 auto elementType = cast<FloatType>(type.getElementType());
464 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
467 Value real = complex::ReOp::create(b, adaptor.getComplex());
468 Value imag = complex::ImOp::create(b, adaptor.getComplex());
470 Value half = arith::ConstantOp::create(b, elementType,
472 Value one = arith::ConstantOp::create(b, elementType,
474 Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
475 Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
476 Value absImag = math::AbsFOp::create(b, imag, fmf);
478 Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
479 Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
481 Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
482 realPlusOne, absImag, fmf);
483 Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
484 Value maxAbsOfRealPlusOneAndImagMinusOne =
485 arith::SelectOp::create(b, useReal, real, maxMinusOne);
486 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
487 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
488 Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
489 Value logOfMaxAbsOfRealPlusOneAndImag =
490 math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
491 Value logOfSqrtPart = math::Log1pOp::create(
492 b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
494 Value r = arith::AddFOp::create(
495 b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
496 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
497 Value resultReal = arith::SelectOp::create(
499 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
502 Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
// Lowers complex.mul with the textbook product formula
//   (a + b*i) * (c + d*i) = (a*c - b*d) + (b*c + a*d)*i,
// applying the op's fast-math flags to every arith op.
513 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
516 auto type = cast<ComplexType>(adaptor.getLhs().getType());
517 auto elementType = cast<FloatType>(type.getElementType());
518 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
519 auto fmfValue = fmf.getValue();
520 Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
521 Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
522 Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
523 Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
// Real part: a*c - b*d
524 Value lhsRealTimesRhsReal =
525 arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
526 Value lhsImagTimesRhsImag =
527 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
528 Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
529 lhsImagTimesRhsImag, fmfValue);
// Imaginary part: b*c + a*d
530 Value lhsImagTimesRhsReal =
531 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
532 Value lhsRealTimesRhsImag =
533 arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
534 Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
535 lhsRealTimesRhsImag, fmfValue);
545 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
547 auto loc = op.getLoc();
548 auto type = cast<ComplexType>(adaptor.getComplex().getType());
549 auto elementType = cast<FloatType>(type.getElementType());
552 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
554 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
555 Value negReal = arith::NegFOp::create(rewriter, loc, real);
556 Value negImag = arith::NegFOp::create(rewriter, loc, imag);
// complex.sin: sin(x + i*y) = sin(x)*cosh(y) + i*cos(x)*sinh(y).
562 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
563 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
568 arith::FastMathFlagsAttr fmf)
const override {
// sum = scaledExp + reciprocalExp = cosh(y)
579 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
580 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
// diff = scaledExp - reciprocalExp = sinh(y)
582 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
583 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
584 return {resultReal, resultImag};
593 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
597 auto type = cast<ComplexType>(op.getType());
598 auto elementType = cast<FloatType>(type.getElementType());
599 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
601 auto cst = [&](APFloat v) {
602 return arith::ConstantOp::create(b, elementType,
605 const auto &floatSemantics = elementType.getFloatSemantics();
607 Value half = arith::ConstantOp::create(b, elementType,
610 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
611 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
612 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
613 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
614 Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
615 Value cos = math::CosOp::create(b, sqrtArg, fmf);
616 Value sin = math::SinOp::create(b, sqrtArg, fmf);
620 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
622 Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
623 Value resultImag = arith::SelectOp::create(
624 b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
625 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
626 arith::FastMathFlags::ninf)) {
627 Value inf = cst(APFloat::getInf(floatSemantics));
628 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
629 Value nan = cst(APFloat::getNaN(floatSemantics));
630 Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
632 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
634 Value absImagIsNotInf = arith::CmpFOp::create(
635 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
637 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
638 Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
641 resultReal = arith::SelectOp::create(
642 b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
644 resultReal = arith::SelectOp::create(
645 b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
647 Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
648 resultImag = arith::SelectOp::create(
650 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
652 resultImag = arith::SelectOp::create(
653 b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
658 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
659 resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
660 resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
672 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
674 auto type = cast<ComplexType>(adaptor.getComplex().getType());
675 auto elementType = cast<FloatType>(type.getElementType());
677 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
679 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
680 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
682 arith::ConstantOp::create(b, elementType, b.
getZeroAttr(elementType));
684 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
686 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
687 Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
689 complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
690 Value realSign = arith::DivFOp::create(b, real, abs, fmf);
691 Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
692 Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
694 adaptor.getComplex(), sign);
// Lowers complex.tanh, and complex.tan through tan(z) = -i * tanh(i*z).
// With z = x + i*y:
//   tanh(z) = (sinh(2x) + i*sin(2y)) / (cosh(2x) + cos(2y)).
// Numerator and denominator are both computed scaled by 2, using expm1 for
// accuracy when x is small.
699 template <
typename Op>
704 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
708 auto type = cast<ComplexType>(adaptor.getComplex().getType());
709 auto elementType = cast<FloatType>(type.getElementType());
710 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
711 const auto &floatSemantics = elementType.getFloatSemantics();
714 complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
716 complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
717 Value negOne = arith::ConstantOp::create(b, elementType,
// tan: rotate the input to i*z = -y + i*x by swapping the parts and negating
// the new real part.
720 if constexpr (std::is_same_v<Op, complex::TanOp>) {
722 std::swap(real, imag);
723 real = arith::MulFOp::create(b, real, negOne, fmf);
726 auto cst = [&](APFloat v) {
727 return arith::ConstantOp::create(b, elementType,
730 Value inf = cst(APFloat::getInf(floatSemantics));
731 Value four = arith::ConstantOp::create(b, elementType,
733 Value twoReal = arith::AddFOp::create(b, real, real, fmf);
734 Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
// realNum = (e^{2x} - 1) - (e^{-2x} - 1) = 2*sinh(2x)
736 Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
737 Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
738 Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
739 expNegTwoRealMinusOne, fmf);
// 4*cos^2(y) = 2*(cos(2y) + 1)
741 Value cosImag = math::CosOp::create(b, imag, fmf);
742 Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
743 Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
744 Value sinImag = math::SinOp::create(b, imag, fmf);
// imagNum = 4*cos(y)*sin(y) = 2*sin(2y)
746 Value imagNum = arith::MulFOp::create(
747 b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
// denom = e^{2x} + e^{-2x} - 2 + 4*cos^2(y) = 2*cosh(2x) + 2*cos(2y)
749 Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
750 expNegTwoRealMinusOne, fmf);
752 arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
// When the exponential sum overflows to +inf (large |x|), realNum / denom
// would be inf/inf. In that case use the limit of the real part instead:
// copysign(1, x).
754 Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
755 expSumMinusTwo, inf, fmf);
756 Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
758 Value resultReal = arith::SelectOp::create(
759 b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
760 Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
// Special-value handling, emitted only when NaN/Inf are not both excluded by
// the fast-math flags. In the visible lines:
// - the real part becomes NaN when imagNum is NaN and |x| is not infinite;
// - the imaginary part becomes 0 when y == 0, or when |x| is infinite and
//   imagNum is NaN.
// Some operand lines of these comparisons are not shown in this fragment.
762 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
763 arith::FastMathFlags::ninf)) {
764 Value absReal = math::AbsFOp::create(b, real, fmf);
765 Value zero = arith::ConstantOp::create(b, elementType,
767 Value nan = cst(APFloat::getNaN(floatSemantics));
769 Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
772 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
773 Value absRealIsNotInf = arith::XOrIOp::create(
776 Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
777 imagNum, imagNum, fmf);
778 Value resultRealIsNaN =
779 arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
780 Value resultImagIsZero = arith::OrIOp::create(
781 b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
783 resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
785 arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
// tan: multiply tanh(i*z) = a + i*b by -i, which gives b - i*a (swap the
// parts, then negate the new imaginary part).
788 if constexpr (std::is_same_v<Op, complex::TanOp>) {
790 std::swap(resultReal, resultImag);
791 resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
804 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
806 auto loc = op.getLoc();
807 auto type = cast<ComplexType>(adaptor.getComplex().getType());
808 auto elementType = cast<FloatType>(type.getElementType());
810 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
812 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
813 Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
826 arith::FastMathFlags fmf) {
827 auto elementType = cast<FloatType>(type.getElementType());
829 Value a = complex::ReOp::create(builder, lhs);
830 Value b = complex::ImOp::create(builder, lhs);
832 Value abs = complex::AbsOp::create(builder, lhs, fmf);
833 Value absToC = math::PowFOp::create(builder, abs, c, fmf);
835 Value negD = arith::NegFOp::create(builder, d, fmf);
836 Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
837 Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
838 Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
840 Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
841 Value lnAbs = math::LogOp::create(builder, abs, fmf);
842 Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
843 Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
844 Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
845 Value cosQ = math::CosOp::create(builder, q, fmf);
846 Value sinQ = math::SinOp::create(builder, q, fmf);
848 Value inf = arith::ConstantOp::create(
849 builder, elementType,
851 APFloat::getInf(elementType.getFloatSemantics())));
852 Value zero = arith::ConstantOp::create(
853 builder, elementType, builder.
getFloatAttr(elementType, 0.0));
854 Value one = arith::ConstantOp::create(builder, elementType,
856 Value complexOne = complex::CreateOp::create(builder, type, one, zero);
857 Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
858 Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
865 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
867 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
869 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
871 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
874 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
875 Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
876 Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
877 Value complexOneOrZero =
878 arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
880 complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
881 Value cutoff0 = arith::SelectOp::create(
883 arith::AndIOp::create(
884 builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
885 complexOneOrZero, coeffCosSin);
891 Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
893 arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
897 Value lhsEqOne = arith::AndIOp::create(
899 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
902 arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
906 Value lhsEqInf = arith::AndIOp::create(
908 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
910 Value rhsGt0 = arith::AndIOp::create(
912 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
913 Value cutoff3 = arith::SelectOp::create(
914 builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
919 Value rhsLt0 = arith::AndIOp::create(
921 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
922 Value cutoff4 = arith::SelectOp::create(
923 builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
933 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
936 auto type = cast<ComplexType>(adaptor.getLhs().getType());
937 auto elementType = cast<FloatType>(type.getElementType());
939 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
940 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
942 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
943 c, d, op.getFastmath())});
952 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
955 auto type = cast<ComplexType>(adaptor.getComplex().getType());
956 auto elementType = cast<FloatType>(type.getElementType());
958 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
960 auto cst = [&](APFloat v) {
961 return arith::ConstantOp::create(b, elementType,
962 b.getFloatAttr(elementType, v));
964 const auto &floatSemantics = elementType.getFloatSemantics();
966 Value inf = cst(APFloat::getInf(floatSemantics));
967 Value negHalf = arith::ConstantOp::create(
968 b, elementType, b.getFloatAttr(elementType, -0.5));
969 Value nan = cst(APFloat::getNaN(floatSemantics));
971 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
972 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
973 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
974 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
975 Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
976 Value cos = math::CosOp::create(b, rsqrtArg, fmf);
977 Value sin = math::SinOp::create(b, rsqrtArg, fmf);
979 Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
980 Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
982 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
983 arith::FastMathFlags::ninf)) {
984 Value negOne = arith::ConstantOp::create(b, elementType,
985 b.getFloatAttr(elementType, -1));
987 Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
988 Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
989 Value negImagSignedZero =
990 arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
992 Value absReal = math::AbsFOp::create(b, real, fmf);
993 Value absImag = math::AbsFOp::create(b, imag, fmf);
995 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
998 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
999 Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1001 Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
1003 Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
1006 arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
1007 resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
1012 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
1014 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
1015 Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
1017 resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
1018 resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
1030 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1032 auto loc = op.getLoc();
1033 auto type = op.getType();
1034 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1037 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1039 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
1056 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1057 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1058 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1059 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1071 TanTanhOpConversion<complex::TanOp>,
1072 TanTanhOpConversion<complex::TanhOp>,
// Pass that rewrites complex-dialect ops into arith/math ops. It derives from
// the generated ConvertComplexToStandardPassBase, which is declared via the
// GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS include at the top of the file.
1083 struct ConvertComplexToStandardPass
1084 :
public impl::ConvertComplexToStandardPassBase<
1085 ConvertComplexToStandardPass> {
1088 void runOnOperation()
override;
// Runs the complex-to-standard conversion. After conversion the arith and
// math dialects are legal. Of the complex dialect, only create/re/im remain
// legal, since they are needed to assemble and split complex values. The
// conversion call itself is not visible here; a failed conversion signals
// pass failure.
1091 void ConvertComplexToStandardPass::runOnOperation() {
1097 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1098 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1101 signalPassFailure();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.