19 #include <type_traits>
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
/// Selects which function of a complex number's magnitude the magnitude
/// helper (`computeAbs`) should produce.
///
/// The helper is invoked with `AbsFn::sqrt` by the complex sqrt lowering and
/// with `AbsFn::rsqrt` by the complex rsqrt lowering; callers that pass no
/// selector presumably get the plain magnitude (the default argument is not
/// visible here — confirm against the helper's declaration).
enum class AbsFn {
  abs,   ///< The magnitude itself.
  sqrt,  ///< Square root of the magnitude.
  rsqrt, ///< Reciprocal square root of the magnitude.
};
45 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
48 Value ratioSq = b.
create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49 Value ratioSqPlusOne = b.
create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
52 if (fn == AbsFn::rsqrt) {
53 ratioSqPlusOne = b.
create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
58 if (fn == AbsFn::sqrt) {
63 Value p025 = b.
create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64 result = b.
create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
66 Value sqrt = b.
create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67 result = b.
create<arith::MulFOp>(
max, sqrt, fmfWithNaNInf);
70 Value isNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71 result, fmfWithNaNInf);
72 return b.
create<arith::SelectOp>(isNaN,
min, result);
79 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
83 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
86 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
87 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
98 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
102 auto type = cast<ComplexType>(op.getType());
103 Type elementType = type.getElementType();
104 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106 Value lhs = adaptor.getLhs();
107 Value rhs = adaptor.getRhs();
109 Value rhsSquared = b.
create<complex::MulOp>(type, rhs, rhs, fmf);
110 Value lhsSquared = b.
create<complex::MulOp>(type, lhs, lhs, fmf);
111 Value rhsSquaredPlusLhsSquared =
112 b.
create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113 Value sqrtOfRhsSquaredPlusLhsSquared =
114 b.
create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
118 Value one = b.
create<arith::ConstantOp>(elementType,
120 Value i = b.
create<complex::CreateOp>(type, zero, one);
121 Value iTimesLhs = b.
create<complex::MulOp>(i, lhs, fmf);
122 Value rhsPlusILhs = b.
create<complex::AddOp>(rhs, iTimesLhs, fmf);
125 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126 Value logResult = b.
create<complex::LogOp>(divResult, fmf);
130 Value negativeI = b.
create<complex::CreateOp>(type, zero, negativeOne);
137 template <
typename ComparisonOp, arith::CmpFPredicate p>
140 using ResultCombiner =
141 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142 arith::AndIOp, arith::OrIOp>;
145 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
147 auto loc = op.getLoc();
148 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 Value realLhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getLhs());
151 Value imagLhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getLhs());
152 Value realRhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getRhs());
153 Value imagRhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getRhs());
154 Value realComparison =
155 rewriter.
create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156 Value imagComparison =
157 rewriter.
create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
168 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
173 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
175 auto type = cast<ComplexType>(adaptor.getLhs().getType());
176 auto elementType = cast<FloatType>(type.getElementType());
178 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180 Value realLhs = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
181 Value realRhs = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
182 Value resultReal = b.
create<BinaryStandardOp>(elementType, realLhs, realRhs,
184 Value imagLhs = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
185 Value imagRhs = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
186 Value resultImag = b.
create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
194 template <
typename TrigonometricOp>
201 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
203 auto loc = op.getLoc();
204 auto type = cast<ComplexType>(adaptor.getComplex().getType());
205 auto elementType = cast<FloatType>(type.getElementType());
206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
209 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
211 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
217 loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
218 Value exp = rewriter.
create<math::ExpOp>(loc, imag, fmf);
219 Value scaledExp = rewriter.
create<arith::MulFOp>(loc, half, exp, fmf);
220 Value reciprocalExp = rewriter.
create<arith::DivFOp>(loc, half, exp, fmf);
221 Value sin = rewriter.
create<math::SinOp>(loc, real, fmf);
222 Value cos = rewriter.
create<math::CosOp>(loc, real, fmf);
225 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
232 virtual std::pair<Value, Value>
235 arith::FastMathFlagsAttr fmf)
const = 0;
238 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
239 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
244 arith::FastMathFlagsAttr fmf)
const override {
255 rewriter.
create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, cos, fmf);
258 rewriter.
create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, sin, fmf);
260 return {resultReal, resultImag};
268 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
270 auto loc = op.getLoc();
271 auto type = cast<ComplexType>(adaptor.getLhs().getType());
272 auto elementType = cast<FloatType>(type.getElementType());
273 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
276 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getLhs());
278 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getLhs());
280 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getRhs());
282 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getRhs());
306 Value rhsRealImagRatio =
307 rewriter.
create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308 Value rhsRealImagDenom = rewriter.
create<arith::AddFOp>(
310 rewriter.
create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
312 Value realNumerator1 = rewriter.
create<arith::AddFOp>(
314 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
316 Value resultReal1 = rewriter.
create<arith::DivFOp>(loc, realNumerator1,
317 rhsRealImagDenom, fmf);
318 Value imagNumerator1 = rewriter.
create<arith::SubFOp>(
320 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
322 Value resultImag1 = rewriter.
create<arith::DivFOp>(loc, imagNumerator1,
323 rhsRealImagDenom, fmf);
325 Value rhsImagRealRatio =
326 rewriter.
create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327 Value rhsImagRealDenom = rewriter.
create<arith::AddFOp>(
329 rewriter.
create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
331 Value realNumerator2 = rewriter.
create<arith::AddFOp>(
333 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
335 Value resultReal2 = rewriter.
create<arith::DivFOp>(loc, realNumerator2,
336 rhsImagRealDenom, fmf);
337 Value imagNumerator2 = rewriter.
create<arith::SubFOp>(
339 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
341 Value resultImag2 = rewriter.
create<arith::DivFOp>(loc, imagNumerator2,
342 rhsImagRealDenom, fmf);
347 loc, elementType, rewriter.
getZeroAttr(elementType));
348 Value rhsRealAbs = rewriter.
create<math::AbsFOp>(loc, rhsReal, fmf);
349 Value rhsRealIsZero = rewriter.
create<arith::CmpFOp>(
350 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351 Value rhsImagAbs = rewriter.
create<math::AbsFOp>(loc, rhsImag, fmf);
352 Value rhsImagIsZero = rewriter.
create<arith::CmpFOp>(
353 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354 Value lhsRealIsNotNaN = rewriter.
create<arith::CmpFOp>(
355 loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356 Value lhsImagIsNotNaN = rewriter.
create<arith::CmpFOp>(
357 loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358 Value lhsContainsNotNaNValue =
359 rewriter.
create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360 Value resultIsInfinity = rewriter.
create<arith::AndIOp>(
361 loc, lhsContainsNotNaNValue,
362 rewriter.
create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
366 elementType, APFloat::getInf(elementType.getFloatSemantics())));
367 Value infWithSignOfRhsReal =
368 rewriter.
create<math::CopySignOp>(loc, inf, rhsReal);
369 Value infinityResultReal =
370 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371 Value infinityResultImag =
372 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
375 Value rhsRealFinite = rewriter.
create<arith::CmpFOp>(
376 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377 Value rhsImagFinite = rewriter.
create<arith::CmpFOp>(
378 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
380 rewriter.
create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381 Value lhsRealAbs = rewriter.
create<math::AbsFOp>(loc, lhsReal, fmf);
382 Value lhsRealInfinite = rewriter.
create<arith::CmpFOp>(
383 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384 Value lhsImagAbs = rewriter.
create<math::AbsFOp>(loc, lhsImag, fmf);
385 Value lhsImagInfinite = rewriter.
create<arith::CmpFOp>(
386 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
388 rewriter.
create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389 Value infNumFiniteDenom =
390 rewriter.
create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
392 loc, elementType, rewriter.
getFloatAttr(elementType, 1));
393 Value lhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
394 loc, rewriter.
create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
396 Value lhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
397 loc, rewriter.
create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
399 Value lhsRealIsInfWithSignTimesRhsReal =
400 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401 Value lhsImagIsInfWithSignTimesRhsImag =
402 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403 Value resultReal3 = rewriter.
create<arith::MulFOp>(
405 rewriter.
create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406 lhsImagIsInfWithSignTimesRhsImag, fmf),
408 Value lhsRealIsInfWithSignTimesRhsImag =
409 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410 Value lhsImagIsInfWithSignTimesRhsReal =
411 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412 Value resultImag3 = rewriter.
create<arith::MulFOp>(
414 rewriter.
create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415 lhsRealIsInfWithSignTimesRhsImag, fmf),
419 Value lhsRealFinite = rewriter.
create<arith::CmpFOp>(
420 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421 Value lhsImagFinite = rewriter.
create<arith::CmpFOp>(
422 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
424 rewriter.
create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425 Value rhsRealInfinite = rewriter.
create<arith::CmpFOp>(
426 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427 Value rhsImagInfinite = rewriter.
create<arith::CmpFOp>(
428 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
430 rewriter.
create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431 Value finiteNumInfiniteDenom =
432 rewriter.
create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433 Value rhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
434 loc, rewriter.
create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
436 Value rhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
437 loc, rewriter.
create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
439 Value rhsRealIsInfWithSignTimesLhsReal =
440 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441 Value rhsImagIsInfWithSignTimesLhsImag =
442 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443 Value resultReal4 = rewriter.
create<arith::MulFOp>(
445 rewriter.
create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446 rhsImagIsInfWithSignTimesLhsImag, fmf),
448 Value rhsRealIsInfWithSignTimesLhsImag =
449 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450 Value rhsImagIsInfWithSignTimesLhsReal =
451 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452 Value resultImag4 = rewriter.
create<arith::MulFOp>(
454 rewriter.
create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455 rhsImagIsInfWithSignTimesLhsReal, fmf),
458 Value realAbsSmallerThanImagAbs = rewriter.
create<arith::CmpFOp>(
459 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460 Value resultReal = rewriter.
create<arith::SelectOp>(
461 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462 Value resultImag = rewriter.
create<arith::SelectOp>(
463 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464 Value resultRealSpecialCase3 = rewriter.
create<arith::SelectOp>(
465 loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466 Value resultImagSpecialCase3 = rewriter.
create<arith::SelectOp>(
467 loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468 Value resultRealSpecialCase2 = rewriter.
create<arith::SelectOp>(
469 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470 Value resultImagSpecialCase2 = rewriter.
create<arith::SelectOp>(
471 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472 Value resultRealSpecialCase1 = rewriter.
create<arith::SelectOp>(
473 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474 Value resultImagSpecialCase1 = rewriter.
create<arith::SelectOp>(
475 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
477 Value resultRealIsNaN = rewriter.
create<arith::CmpFOp>(
478 loc, arith::CmpFPredicate::UNO, resultReal, zero);
479 Value resultImagIsNaN = rewriter.
create<arith::CmpFOp>(
480 loc, arith::CmpFPredicate::UNO, resultImag, zero);
482 rewriter.
create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483 Value resultRealWithSpecialCases = rewriter.
create<arith::SelectOp>(
484 loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485 Value resultImagWithSpecialCases = rewriter.
create<arith::SelectOp>(
486 loc, resultIsNaN, resultImagSpecialCase1, resultImag);
489 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
498 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
500 auto loc = op.getLoc();
501 auto type = cast<ComplexType>(adaptor.getComplex().getType());
502 auto elementType = cast<FloatType>(type.getElementType());
503 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
506 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
508 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509 Value expReal = rewriter.
create<math::ExpOp>(loc, real, fmf.getValue());
510 Value cosImag = rewriter.
create<math::CosOp>(loc, imag, fmf.getValue());
512 rewriter.
create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513 Value sinImag = rewriter.
create<math::SinOp>(loc, imag, fmf.getValue());
515 rewriter.
create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
525 arith::FastMathFlagsAttr fmf) {
526 auto argType = mlir::cast<FloatType>(arg.
getType());
529 for (
unsigned i = 1; i < coefficients.size(); ++i) {
530 poly = b.
create<math::FmaOp>(
546 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
548 auto type = op.getType();
549 auto elemType = mlir::cast<FloatType>(type.getElementType());
551 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
553 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
554 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
559 Value expm1Real = b.
create<math::ExpM1Op>(real, fmf);
560 Value expReal = b.
create<arith::AddFOp>(expm1Real, one, fmf);
563 Value cosm1Imag = emitCosm1(imag, fmf, b);
564 Value cosImag = b.
create<arith::AddFOp>(cosm1Imag, one, fmf);
567 b.
create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
569 Value imagIsZero = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570 zero, fmf.getValue());
572 imagIsZero, zero, b.
create<arith::MulFOp>(expReal, sinImag, fmf));
580 Value emitCosm1(
Value arg, arith::FastMathFlagsAttr fmf,
582 auto argType = mlir::cast<FloatType>(arg.
getType());
588 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591 4.1666666666666666609054E-2,
594 Value forLargeArg = b.
create<arith::AddFOp>(cos, negOne, fmf);
596 Value argPow2 = b.
create<arith::MulFOp>(arg, arg, fmf);
597 Value argPow4 = b.
create<arith::MulFOp>(argPow2, argPow2, fmf);
598 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
601 b.
create<arith::AddFOp>(b.
create<arith::MulFOp>(argPow4, poly, fmf),
602 b.
create<arith::MulFOp>(negHalf, argPow2, fmf));
607 Value cond = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608 piOver4Pow2, fmf.getValue());
609 return b.
create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
617 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
619 auto type = cast<ComplexType>(adaptor.getComplex().getType());
620 auto elementType = cast<FloatType>(type.getElementType());
621 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
624 Value abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(),
626 Value resultReal = b.
create<math::LogOp>(elementType,
abs, fmf.getValue());
627 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
628 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
630 b.
create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
641 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
643 auto type = cast<ComplexType>(adaptor.getComplex().getType());
644 auto elementType = cast<FloatType>(type.getElementType());
645 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
648 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
649 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
651 Value half = b.
create<arith::ConstantOp>(elementType,
653 Value one = b.
create<arith::ConstantOp>(elementType,
655 Value realPlusOne = b.
create<arith::AddFOp>(real, one, fmf);
656 Value absRealPlusOne = b.
create<math::AbsFOp>(realPlusOne, fmf);
659 Value maxAbs = b.
create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
660 Value minAbs = b.
create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
662 Value useReal = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
663 realPlusOne, absImag, fmf);
664 Value maxMinusOne = b.
create<arith::SubFOp>(maxAbs, one, fmf);
665 Value maxAbsOfRealPlusOneAndImagMinusOne =
666 b.
create<arith::SelectOp>(useReal, real, maxMinusOne);
667 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
668 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
669 Value minMaxRatio = b.
create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
670 Value logOfMaxAbsOfRealPlusOneAndImag =
671 b.
create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
673 b.
create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
676 b.
create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
677 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
679 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
681 Value resultImag = b.
create<math::Atan2Op>(imag, realPlusOne, fmf);
692 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
695 auto type = cast<ComplexType>(adaptor.getLhs().getType());
696 auto elementType = cast<FloatType>(type.getElementType());
697 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698 auto fmfValue = fmf.getValue();
699 Value lhsReal = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
700 Value lhsImag = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
701 Value rhsReal = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
702 Value rhsImag = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
703 Value lhsRealTimesRhsReal =
704 b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
705 Value lhsImagTimesRhsImag =
706 b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
707 Value real = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
708 lhsImagTimesRhsImag, fmfValue);
709 Value lhsImagTimesRhsReal =
710 b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
711 Value lhsRealTimesRhsImag =
712 b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
713 Value imag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
714 lhsRealTimesRhsImag, fmfValue);
724 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
726 auto loc = op.getLoc();
727 auto type = cast<ComplexType>(adaptor.getComplex().getType());
728 auto elementType = cast<FloatType>(type.getElementType());
731 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
733 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
734 Value negReal = rewriter.
create<arith::NegFOp>(loc, real);
735 Value negImag = rewriter.
create<arith::NegFOp>(loc, imag);
741 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
742 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
747 arith::FastMathFlagsAttr fmf)
const override {
758 rewriter.
create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
759 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, sin, fmf);
761 rewriter.
create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
762 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, cos, fmf);
763 return {resultReal, resultImag};
772 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
776 auto type = cast<ComplexType>(op.getType());
777 auto elementType = cast<FloatType>(type.getElementType());
778 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
780 auto cst = [&](APFloat v) {
781 return b.
create<arith::ConstantOp>(elementType,
784 const auto &floatSemantics = elementType.getFloatSemantics();
786 Value half = b.
create<arith::ConstantOp>(elementType,
789 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
790 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
791 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
792 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
793 Value sqrtArg = b.
create<arith::MulFOp>(argArg, half, fmf);
799 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
801 Value resultReal = b.
create<arith::MulFOp>(absSqrt, cos, fmf);
803 sinIsZero, zero, b.
create<arith::MulFOp>(absSqrt, sin, fmf));
804 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
805 arith::FastMathFlags::ninf)) {
806 Value inf = cst(APFloat::getInf(floatSemantics));
807 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
808 Value nan = cst(APFloat::getNaN(floatSemantics));
809 Value absImag = b.
create<math::AbsFOp>(elementType, imag, fmf);
812 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
813 Value absImagIsNotInf =
814 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
816 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
818 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
820 resultReal = b.
create<arith::SelectOp>(
821 b.
create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
823 resultReal = b.
create<arith::SelectOp>(
824 b.
create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
826 Value imagSignInf = b.
create<math::CopySignOp>(inf, imag, fmf);
827 resultImag = b.
create<arith::SelectOp>(
828 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
830 resultImag = b.
create<arith::SelectOp>(
831 b.
create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
836 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
837 resultReal = b.
create<arith::SelectOp>(resultIsZero, zero, resultReal);
838 resultImag = b.
create<arith::SelectOp>(resultIsZero, zero, resultImag);
850 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
852 auto type = cast<ComplexType>(adaptor.getComplex().getType());
853 auto elementType = cast<FloatType>(type.getElementType());
855 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
857 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
858 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
862 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
864 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
865 Value isZero = b.
create<arith::AndIOp>(realIsZero, imagIsZero);
866 auto abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
869 Value sign = b.
create<complex::CreateOp>(type, realSign, imagSign);
871 adaptor.getComplex(), sign);
876 template <
typename Op>
881 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
885 auto type = cast<ComplexType>(adaptor.getComplex().getType());
886 auto elementType = cast<FloatType>(type.getElementType());
887 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
888 const auto &floatSemantics = elementType.getFloatSemantics();
891 b.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
893 b.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
897 if constexpr (std::is_same_v<Op, complex::TanOp>) {
899 std::swap(real, imag);
900 real = b.
create<arith::MulFOp>(real, negOne, fmf);
903 auto cst = [&](APFloat v) {
904 return b.
create<arith::ConstantOp>(elementType,
907 Value inf = cst(APFloat::getInf(floatSemantics));
908 Value four = b.
create<arith::ConstantOp>(elementType,
910 Value twoReal = b.
create<arith::AddFOp>(real, real, fmf);
911 Value negTwoReal = b.
create<arith::MulFOp>(negOne, twoReal, fmf);
913 Value expTwoRealMinusOne = b.
create<math::ExpM1Op>(twoReal, fmf);
914 Value expNegTwoRealMinusOne = b.
create<math::ExpM1Op>(negTwoReal, fmf);
916 b.
create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
919 Value cosImagSq = b.
create<arith::MulFOp>(cosImag, cosImag, fmf);
920 Value twoCosTwoImagPlusOne = b.
create<arith::MulFOp>(cosImagSq, four, fmf);
924 four, b.
create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
926 Value expSumMinusTwo =
927 b.
create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
929 b.
create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
931 Value isInf = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
932 expSumMinusTwo, inf, fmf);
933 Value realLimit = b.
create<math::CopySignOp>(negOne, real, fmf);
936 isInf, realLimit, b.
create<arith::DivFOp>(realNum, denom, fmf));
937 Value resultImag = b.
create<arith::DivFOp>(imagNum, denom, fmf);
939 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
940 arith::FastMathFlags::ninf)) {
944 Value nan = cst(APFloat::getNaN(floatSemantics));
947 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
949 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
951 absRealIsInf, b.
create<arith::ConstantIntOp>(
true, 1));
953 Value imagNumIsNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
954 imagNum, imagNum, fmf);
955 Value resultRealIsNaN =
956 b.
create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
958 imagIsZero, b.
create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
960 resultReal = b.
create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
962 b.
create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
965 if constexpr (std::is_same_v<Op, complex::TanOp>) {
967 std::swap(resultReal, resultImag);
968 resultImag = b.
create<arith::MulFOp>(resultImag, negOne, fmf);
981 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
983 auto loc = op.getLoc();
984 auto type = cast<ComplexType>(adaptor.getComplex().getType());
985 auto elementType = cast<FloatType>(type.getElementType());
987 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
989 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
990 Value negImag = rewriter.
create<arith::NegFOp>(loc, elementType, imag);
1003 arith::FastMathFlags fmf) {
1004 auto elementType = cast<FloatType>(type.getElementType());
1012 Value negD = builder.
create<arith::NegFOp>(d, fmf);
1013 Value argLhs = builder.
create<math::Atan2Op>(b, a, fmf);
1014 Value negDArgLhs = builder.
create<arith::MulFOp>(negD, argLhs, fmf);
1015 Value expNegDArgLhs = builder.
create<math::ExpOp>(negDArgLhs, fmf);
1017 Value coeff = builder.
create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1019 Value cArgLhs = builder.
create<arith::MulFOp>(c, argLhs, fmf);
1020 Value dLnAbs = builder.
create<arith::MulFOp>(d, lnAbs, fmf);
1021 Value q = builder.
create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1028 APFloat::getInf(elementType.getFloatSemantics())));
1033 Value complexOne = builder.
create<complex::CreateOp>(type, one, zero);
1034 Value complexZero = builder.
create<complex::CreateOp>(type, zero, zero);
1035 Value complexInf = builder.
create<complex::CreateOp>(type, inf, zero);
1042 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
abs, zero, fmf);
1044 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1046 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1048 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1051 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1052 Value coeffCosQ = builder.
create<arith::MulFOp>(coeff, cosQ, fmf);
1053 Value coeffSinQ = builder.
create<arith::MulFOp>(coeff, sinQ, fmf);
1054 Value complexOneOrZero =
1055 builder.
create<arith::SelectOp>(cEqZero, complexOne, complexZero);
1057 builder.
create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1059 builder.
create<arith::AndIOp>(
1060 builder.
create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1061 complexOneOrZero, coeffCosSin);
1067 Value rhsEqZero = builder.
create<arith::AndIOp>(cEqZero, dEqZero);
1069 builder.
create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1074 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1077 builder.
create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1082 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1086 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1088 builder.
create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1092 Value rhsLt0 = builder.create<arith::AndIOp>(
1094 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1095 Value cutoff4 = builder.create<arith::SelectOp>(
1096 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1105 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1108 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1109 auto elementType = cast<FloatType>(type.getElementType());
1111 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1112 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1114 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1115 c, d, op.getFastmath())});
1124 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1127 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1128 auto elementType = cast<FloatType>(type.getElementType());
1130 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1132 auto cst = [&](APFloat v) {
1133 return b.
create<arith::ConstantOp>(elementType,
1136 const auto &floatSemantics = elementType.getFloatSemantics();
1138 Value inf = cst(APFloat::getInf(floatSemantics));
1141 Value nan = cst(APFloat::getNaN(floatSemantics));
1143 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
1144 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
1145 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1146 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
1147 Value rsqrtArg = b.
create<arith::MulFOp>(argArg, negHalf, fmf);
1151 Value resultReal = b.
create<arith::MulFOp>(absRsqrt, cos, fmf);
1152 Value resultImag = b.
create<arith::MulFOp>(absRsqrt, sin, fmf);
1154 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1155 arith::FastMathFlags::ninf)) {
1159 Value realSignedZero = b.
create<math::CopySignOp>(zero, real, fmf);
1160 Value imagSignedZero = b.
create<math::CopySignOp>(zero, imag, fmf);
1161 Value negImagSignedZero =
1162 b.
create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1164 Value absReal = b.
create<math::AbsFOp>(real, fmf);
1165 Value absImag = b.
create<math::AbsFOp>(imag, fmf);
1167 Value absImagIsInf =
1168 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1170 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1172 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1173 Value inIsNanInf = b.
create<arith::AndIOp>(absImagIsInf, realIsNan);
1175 Value resultIsZero = b.
create<arith::OrIOp>(inIsNanInf, realIsInf);
1178 b.
create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1179 resultImag = b.
create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1184 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1186 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1187 Value isZero = b.
create<arith::AndIOp>(isRealZero, isImagZero);
1189 resultReal = b.
create<arith::SelectOp>(isZero, inf, resultReal);
1190 resultImag = b.
create<arith::SelectOp>(isZero, nan, resultImag);
1202 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1204 auto loc = op.getLoc();
1205 auto type = op.getType();
1206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1209 rewriter.
create<complex::ReOp>(loc, type, adaptor.getComplex());
1211 rewriter.
create<complex::ImOp>(loc, type, adaptor.getComplex());
1228 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1229 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1230 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1231 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1244 TanTanhOpConversion<complex::TanOp>,
1245 TanTanhOpConversion<complex::TanhOp>,
1253 struct ConvertComplexToStandardPass
1254 :
public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1255 void runOnOperation()
override;
1258 void ConvertComplexToStandardPass::runOnOperation() {
1264 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1265 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1268 signalPassFailure();
1273 return std::make_unique<ConvertComplexToStandardPass>();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing `op`).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.