19 #include <type_traits>
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
// Selects which quantity `computeAbs` derives from the real and imaginary
// parts of a complex value, as seen from its uses in this file:
//  - `AbsFn::sqrt` is requested by the complex.sqrt lowering, which names the
//    result `absSqrt`. In computeAbs this takes the branch that multiplies a
//    `sqrt` value by `pow(ratioSqPlusOne, quarter)`.
//  - `AbsFn::rsqrt` makes computeAbs apply math.rsqrt to `ratioSqPlusOne`
//    (and possibly to other intermediates on lines not shown here) before the
//    final product is formed.
//  - The remaining path computes `max * sqrt(ratio * ratio + 1)`, i.e. the
//    overflow-avoiding form of the modulus |z|. NOTE(review): presumably this
//    is the `abs` case; `ratio` is presumably min/max.
// NOTE(review): the default value of computeAbs's AbsFn parameter is not
// visible here (AbsOpConversion calls it without one). Confirm it is
// AbsFn::abs.
30 enum class AbsFn {
abs, sqrt, rsqrt };
45 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
48 Value ratioSq = b.
create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49 Value ratioSqPlusOne = b.
create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
52 if (fn == AbsFn::rsqrt) {
53 ratioSqPlusOne = b.
create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
58 if (fn == AbsFn::sqrt) {
63 Value p025 = b.
create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64 result = b.
create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
66 Value sqrt = b.
create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67 result = b.
create<arith::MulFOp>(
max, sqrt, fmfWithNaNInf);
70 Value isNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71 result, fmfWithNaNInf);
72 return b.
create<arith::SelectOp>(isNaN,
min, result);
79 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
83 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
86 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
87 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
98 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
102 auto type = cast<ComplexType>(op.getType());
103 Type elementType = type.getElementType();
104 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106 Value lhs = adaptor.getLhs();
107 Value rhs = adaptor.getRhs();
109 Value rhsSquared = b.
create<complex::MulOp>(type, rhs, rhs, fmf);
110 Value lhsSquared = b.
create<complex::MulOp>(type, lhs, lhs, fmf);
111 Value rhsSquaredPlusLhsSquared =
112 b.
create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113 Value sqrtOfRhsSquaredPlusLhsSquared =
114 b.
create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
118 Value one = b.
create<arith::ConstantOp>(elementType,
120 Value i = b.
create<complex::CreateOp>(type, zero, one);
121 Value iTimesLhs = b.
create<complex::MulOp>(i, lhs, fmf);
122 Value rhsPlusILhs = b.
create<complex::AddOp>(rhs, iTimesLhs, fmf);
125 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126 Value logResult = b.
create<complex::LogOp>(divResult, fmf);
130 Value negativeI = b.
create<complex::CreateOp>(type, zero, negativeOne);
137 template <
typename ComparisonOp, arith::CmpFPredicate p>
140 using ResultCombiner =
141 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142 arith::AndIOp, arith::OrIOp>;
145 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
148 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 Value realLhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getLhs());
151 Value imagLhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getLhs());
152 Value realRhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getRhs());
153 Value imagRhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getRhs());
154 Value realComparison =
155 rewriter.
create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156 Value imagComparison =
157 rewriter.
create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
168 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
173 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
175 auto type = cast<ComplexType>(adaptor.getLhs().getType());
176 auto elementType = cast<FloatType>(type.getElementType());
178 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180 Value realLhs = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
181 Value realRhs = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
182 Value resultReal = b.
create<BinaryStandardOp>(elementType, realLhs, realRhs,
184 Value imagLhs = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
185 Value imagRhs = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
186 Value resultImag = b.
create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
194 template <
typename TrigonometricOp>
201 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
204 auto type = cast<ComplexType>(adaptor.getComplex().getType());
205 auto elementType = cast<FloatType>(type.getElementType());
206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
209 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
211 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
217 loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
218 Value exp = rewriter.
create<math::ExpOp>(loc, imag, fmf);
219 Value scaledExp = rewriter.
create<arith::MulFOp>(loc, half, exp, fmf);
220 Value reciprocalExp = rewriter.
create<arith::DivFOp>(loc, half, exp, fmf);
221 Value sin = rewriter.
create<math::SinOp>(loc, real, fmf);
222 Value cos = rewriter.
create<math::CosOp>(loc, real, fmf);
225 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
232 virtual std::pair<Value, Value>
235 arith::FastMathFlagsAttr fmf)
const = 0;
238 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
239 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
244 arith::FastMathFlagsAttr fmf)
const override {
255 rewriter.
create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, cos, fmf);
258 rewriter.
create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, sin, fmf);
260 return {resultReal, resultImag};
268 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
271 auto type = cast<ComplexType>(adaptor.getLhs().getType());
272 auto elementType = cast<FloatType>(type.getElementType());
273 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
276 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getLhs());
278 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getLhs());
280 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getRhs());
282 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getRhs());
306 Value rhsRealImagRatio =
307 rewriter.
create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308 Value rhsRealImagDenom = rewriter.
create<arith::AddFOp>(
310 rewriter.
create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
312 Value realNumerator1 = rewriter.
create<arith::AddFOp>(
314 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
316 Value resultReal1 = rewriter.
create<arith::DivFOp>(loc, realNumerator1,
317 rhsRealImagDenom, fmf);
318 Value imagNumerator1 = rewriter.
create<arith::SubFOp>(
320 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
322 Value resultImag1 = rewriter.
create<arith::DivFOp>(loc, imagNumerator1,
323 rhsRealImagDenom, fmf);
325 Value rhsImagRealRatio =
326 rewriter.
create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327 Value rhsImagRealDenom = rewriter.
create<arith::AddFOp>(
329 rewriter.
create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
331 Value realNumerator2 = rewriter.
create<arith::AddFOp>(
333 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
335 Value resultReal2 = rewriter.
create<arith::DivFOp>(loc, realNumerator2,
336 rhsImagRealDenom, fmf);
337 Value imagNumerator2 = rewriter.
create<arith::SubFOp>(
339 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
341 Value resultImag2 = rewriter.
create<arith::DivFOp>(loc, imagNumerator2,
342 rhsImagRealDenom, fmf);
347 loc, elementType, rewriter.
getZeroAttr(elementType));
348 Value rhsRealAbs = rewriter.
create<math::AbsFOp>(loc, rhsReal, fmf);
349 Value rhsRealIsZero = rewriter.
create<arith::CmpFOp>(
350 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351 Value rhsImagAbs = rewriter.
create<math::AbsFOp>(loc, rhsImag, fmf);
352 Value rhsImagIsZero = rewriter.
create<arith::CmpFOp>(
353 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354 Value lhsRealIsNotNaN = rewriter.
create<arith::CmpFOp>(
355 loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356 Value lhsImagIsNotNaN = rewriter.
create<arith::CmpFOp>(
357 loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358 Value lhsContainsNotNaNValue =
359 rewriter.
create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360 Value resultIsInfinity = rewriter.
create<arith::AndIOp>(
361 loc, lhsContainsNotNaNValue,
362 rewriter.
create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
366 elementType, APFloat::getInf(elementType.getFloatSemantics())));
367 Value infWithSignOfRhsReal =
368 rewriter.
create<math::CopySignOp>(loc, inf, rhsReal);
369 Value infinityResultReal =
370 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371 Value infinityResultImag =
372 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
375 Value rhsRealFinite = rewriter.
create<arith::CmpFOp>(
376 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377 Value rhsImagFinite = rewriter.
create<arith::CmpFOp>(
378 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
380 rewriter.
create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381 Value lhsRealAbs = rewriter.
create<math::AbsFOp>(loc, lhsReal, fmf);
382 Value lhsRealInfinite = rewriter.
create<arith::CmpFOp>(
383 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384 Value lhsImagAbs = rewriter.
create<math::AbsFOp>(loc, lhsImag, fmf);
385 Value lhsImagInfinite = rewriter.
create<arith::CmpFOp>(
386 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
388 rewriter.
create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389 Value infNumFiniteDenom =
390 rewriter.
create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
392 loc, elementType, rewriter.
getFloatAttr(elementType, 1));
393 Value lhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
394 loc, rewriter.
create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
396 Value lhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
397 loc, rewriter.
create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
399 Value lhsRealIsInfWithSignTimesRhsReal =
400 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401 Value lhsImagIsInfWithSignTimesRhsImag =
402 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403 Value resultReal3 = rewriter.
create<arith::MulFOp>(
405 rewriter.
create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406 lhsImagIsInfWithSignTimesRhsImag, fmf),
408 Value lhsRealIsInfWithSignTimesRhsImag =
409 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410 Value lhsImagIsInfWithSignTimesRhsReal =
411 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412 Value resultImag3 = rewriter.
create<arith::MulFOp>(
414 rewriter.
create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415 lhsRealIsInfWithSignTimesRhsImag, fmf),
419 Value lhsRealFinite = rewriter.
create<arith::CmpFOp>(
420 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421 Value lhsImagFinite = rewriter.
create<arith::CmpFOp>(
422 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
424 rewriter.
create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425 Value rhsRealInfinite = rewriter.
create<arith::CmpFOp>(
426 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427 Value rhsImagInfinite = rewriter.
create<arith::CmpFOp>(
428 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
430 rewriter.
create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431 Value finiteNumInfiniteDenom =
432 rewriter.
create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433 Value rhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
434 loc, rewriter.
create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
436 Value rhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
437 loc, rewriter.
create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
439 Value rhsRealIsInfWithSignTimesLhsReal =
440 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441 Value rhsImagIsInfWithSignTimesLhsImag =
442 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443 Value resultReal4 = rewriter.
create<arith::MulFOp>(
445 rewriter.
create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446 rhsImagIsInfWithSignTimesLhsImag, fmf),
448 Value rhsRealIsInfWithSignTimesLhsImag =
449 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450 Value rhsImagIsInfWithSignTimesLhsReal =
451 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452 Value resultImag4 = rewriter.
create<arith::MulFOp>(
454 rewriter.
create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455 rhsImagIsInfWithSignTimesLhsReal, fmf),
458 Value realAbsSmallerThanImagAbs = rewriter.
create<arith::CmpFOp>(
459 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460 Value resultReal = rewriter.
create<arith::SelectOp>(
461 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462 Value resultImag = rewriter.
create<arith::SelectOp>(
463 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464 Value resultRealSpecialCase3 = rewriter.
create<arith::SelectOp>(
465 loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466 Value resultImagSpecialCase3 = rewriter.
create<arith::SelectOp>(
467 loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468 Value resultRealSpecialCase2 = rewriter.
create<arith::SelectOp>(
469 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470 Value resultImagSpecialCase2 = rewriter.
create<arith::SelectOp>(
471 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472 Value resultRealSpecialCase1 = rewriter.
create<arith::SelectOp>(
473 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474 Value resultImagSpecialCase1 = rewriter.
create<arith::SelectOp>(
475 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
477 Value resultRealIsNaN = rewriter.
create<arith::CmpFOp>(
478 loc, arith::CmpFPredicate::UNO, resultReal, zero);
479 Value resultImagIsNaN = rewriter.
create<arith::CmpFOp>(
480 loc, arith::CmpFPredicate::UNO, resultImag, zero);
482 rewriter.
create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483 Value resultRealWithSpecialCases = rewriter.
create<arith::SelectOp>(
484 loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485 Value resultImagWithSpecialCases = rewriter.
create<arith::SelectOp>(
486 loc, resultIsNaN, resultImagSpecialCase1, resultImag);
489 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
498 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
501 auto type = cast<ComplexType>(adaptor.getComplex().getType());
502 auto elementType = cast<FloatType>(type.getElementType());
503 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
506 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
508 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509 Value expReal = rewriter.
create<math::ExpOp>(loc, real, fmf.getValue());
510 Value cosImag = rewriter.
create<math::CosOp>(loc, imag, fmf.getValue());
512 rewriter.
create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513 Value sinImag = rewriter.
create<math::SinOp>(loc, imag, fmf.getValue());
515 rewriter.
create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
527 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
529 auto type = cast<ComplexType>(adaptor.getComplex().getType());
530 auto elementType = cast<FloatType>(type.getElementType());
531 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
534 Value exp = b.
create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
536 Value real = b.
create<complex::ReOp>(elementType, exp);
537 Value one = b.
create<arith::ConstantOp>(elementType,
539 Value realMinusOne = b.
create<arith::SubFOp>(real, one, fmf.getValue());
540 Value imag = b.
create<complex::ImOp>(elementType, exp);
552 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
554 auto type = cast<ComplexType>(adaptor.getComplex().getType());
555 auto elementType = cast<FloatType>(type.getElementType());
556 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
559 Value abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(),
561 Value resultReal = b.
create<math::LogOp>(elementType,
abs, fmf.getValue());
562 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
563 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
565 b.
create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
576 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
578 auto type = cast<ComplexType>(adaptor.getComplex().getType());
579 auto elementType = cast<FloatType>(type.getElementType());
580 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
583 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
584 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
586 Value half = b.
create<arith::ConstantOp>(elementType,
588 Value one = b.
create<arith::ConstantOp>(elementType,
590 Value realPlusOne = b.
create<arith::AddFOp>(real, one, fmf);
591 Value absRealPlusOne = b.
create<math::AbsFOp>(realPlusOne, fmf);
594 Value maxAbs = b.
create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
595 Value minAbs = b.
create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
597 Value useReal = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
598 realPlusOne, absImag, fmf);
599 Value maxMinusOne = b.
create<arith::SubFOp>(maxAbs, one, fmf);
600 Value maxAbsOfRealPlusOneAndImagMinusOne =
601 b.
create<arith::SelectOp>(useReal, real, maxMinusOne);
602 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
603 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
604 Value minMaxRatio = b.
create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
605 Value logOfMaxAbsOfRealPlusOneAndImag =
606 b.
create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
608 b.
create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
611 b.
create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
612 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
614 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
616 Value resultImag = b.
create<math::Atan2Op>(imag, realPlusOne, fmf);
627 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
630 auto type = cast<ComplexType>(adaptor.getLhs().getType());
631 auto elementType = cast<FloatType>(type.getElementType());
632 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
633 auto fmfValue = fmf.getValue();
635 Value lhsReal = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
636 Value lhsRealAbs = b.
create<math::AbsFOp>(lhsReal, fmfValue);
637 Value lhsImag = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
638 Value lhsImagAbs = b.
create<math::AbsFOp>(lhsImag, fmfValue);
639 Value rhsReal = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
640 Value rhsRealAbs = b.
create<math::AbsFOp>(rhsReal, fmfValue);
641 Value rhsImag = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
642 Value rhsImagAbs = b.
create<math::AbsFOp>(rhsImag, fmfValue);
644 Value lhsRealTimesRhsReal =
645 b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
646 Value lhsRealTimesRhsRealAbs =
647 b.
create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
648 Value lhsImagTimesRhsImag =
649 b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
650 Value lhsImagTimesRhsImagAbs =
651 b.
create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
652 Value real = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
653 lhsImagTimesRhsImag, fmfValue);
655 Value lhsImagTimesRhsReal =
656 b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
657 Value lhsImagTimesRhsRealAbs =
658 b.
create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
659 Value lhsRealTimesRhsImag =
660 b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
661 Value lhsRealTimesRhsImagAbs =
662 b.
create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
663 Value imag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
664 lhsRealTimesRhsImag, fmfValue);
668 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
670 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
671 Value isNan = b.
create<arith::AndIOp>(realIsNan, imagIsNan);
676 APFloat::getInf(elementType.getFloatSemantics())));
680 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
682 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
683 Value lhsIsInf = b.
create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
685 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
687 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
690 Value one = b.
create<arith::ConstantOp>(elementType,
692 Value lhsRealIsInfFloat =
693 b.
create<arith::SelectOp>(lhsRealIsInf, one, zero);
694 lhsReal = b.
create<arith::SelectOp>(
695 lhsIsInf, b.
create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
697 Value lhsImagIsInfFloat =
698 b.
create<arith::SelectOp>(lhsImagIsInf, one, zero);
699 lhsImag = b.
create<arith::SelectOp>(
700 lhsIsInf, b.
create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
702 Value lhsIsInfAndRhsRealIsNan =
703 b.
create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
704 rhsReal = b.
create<arith::SelectOp>(
705 lhsIsInfAndRhsRealIsNan, b.
create<math::CopySignOp>(zero, rhsReal),
707 Value lhsIsInfAndRhsImagIsNan =
708 b.
create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
709 rhsImag = b.
create<arith::SelectOp>(
710 lhsIsInfAndRhsImagIsNan, b.
create<math::CopySignOp>(zero, rhsImag),
715 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
717 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
718 Value rhsIsInf = b.
create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
720 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
722 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
723 Value rhsRealIsInfFloat =
724 b.
create<arith::SelectOp>(rhsRealIsInf, one, zero);
725 rhsReal = b.
create<arith::SelectOp>(
726 rhsIsInf, b.
create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
728 Value rhsImagIsInfFloat =
729 b.
create<arith::SelectOp>(rhsImagIsInf, one, zero);
730 rhsImag = b.
create<arith::SelectOp>(
731 rhsIsInf, b.
create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
733 Value rhsIsInfAndLhsRealIsNan =
734 b.
create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
735 lhsReal = b.
create<arith::SelectOp>(
736 rhsIsInfAndLhsRealIsNan, b.
create<math::CopySignOp>(zero, lhsReal),
738 Value rhsIsInfAndLhsImagIsNan =
739 b.
create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
740 lhsImag = b.
create<arith::SelectOp>(
741 rhsIsInfAndLhsImagIsNan, b.
create<math::CopySignOp>(zero, lhsImag),
743 Value recalc = b.
create<arith::OrIOp>(lhsIsInf, rhsIsInf);
747 Value lhsRealTimesRhsRealIsInf = b.
create<arith::CmpFOp>(
748 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
749 Value lhsImagTimesRhsImagIsInf = b.
create<arith::CmpFOp>(
750 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
751 Value isSpecialCase = b.
create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
752 lhsImagTimesRhsImagIsInf);
753 Value lhsRealTimesRhsImagIsInf = b.
create<arith::CmpFOp>(
754 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
756 b.
create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
757 Value lhsImagTimesRhsRealIsInf = b.
create<arith::CmpFOp>(
758 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
760 b.
create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
765 isSpecialCase = b.
create<arith::AndIOp>(isSpecialCase, notRecalc);
766 Value isSpecialCaseAndLhsRealIsNan =
767 b.
create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
768 lhsReal = b.
create<arith::SelectOp>(
769 isSpecialCaseAndLhsRealIsNan, b.
create<math::CopySignOp>(zero, lhsReal),
771 Value isSpecialCaseAndLhsImagIsNan =
772 b.
create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
773 lhsImag = b.
create<arith::SelectOp>(
774 isSpecialCaseAndLhsImagIsNan, b.
create<math::CopySignOp>(zero, lhsImag),
776 Value isSpecialCaseAndRhsRealIsNan =
777 b.
create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
778 rhsReal = b.
create<arith::SelectOp>(
779 isSpecialCaseAndRhsRealIsNan, b.
create<math::CopySignOp>(zero, rhsReal),
781 Value isSpecialCaseAndRhsImagIsNan =
782 b.
create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
783 rhsImag = b.
create<arith::SelectOp>(
784 isSpecialCaseAndRhsImagIsNan, b.
create<math::CopySignOp>(zero, rhsImag),
786 recalc = b.
create<arith::OrIOp>(recalc, isSpecialCase);
787 recalc = b.
create<arith::AndIOp>(isNan, recalc);
790 lhsRealTimesRhsReal = b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
791 lhsImagTimesRhsImag = b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
792 Value newReal = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
793 lhsImagTimesRhsImag, fmfValue);
794 real = b.
create<arith::SelectOp>(
795 recalc, b.
create<arith::MulFOp>(inf, newReal, fmfValue), real);
798 lhsImagTimesRhsReal = b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
799 lhsRealTimesRhsImag = b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
800 Value newImag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
801 lhsRealTimesRhsImag, fmfValue);
802 imag = b.
create<arith::SelectOp>(
803 recalc, b.
create<arith::MulFOp>(inf, newImag, fmfValue), imag);
814 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
817 auto type = cast<ComplexType>(adaptor.getComplex().getType());
818 auto elementType = cast<FloatType>(type.getElementType());
821 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
823 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
824 Value negReal = rewriter.
create<arith::NegFOp>(loc, real);
825 Value negImag = rewriter.
create<arith::NegFOp>(loc, imag);
831 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
832 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
837 arith::FastMathFlagsAttr fmf)
const override {
848 rewriter.
create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
849 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, sin, fmf);
851 rewriter.
create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
852 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, cos, fmf);
853 return {resultReal, resultImag};
862 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
866 auto type = cast<ComplexType>(op.getType());
867 auto elementType = cast<FloatType>(type.getElementType());
868 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
870 auto cst = [&](APFloat v) {
871 return b.
create<arith::ConstantOp>(elementType,
874 const auto &floatSemantics = elementType.getFloatSemantics();
876 Value half = b.
create<arith::ConstantOp>(elementType,
879 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
880 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
881 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
882 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
883 Value sqrtArg = b.
create<arith::MulFOp>(argArg, half, fmf);
889 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
891 Value resultReal = b.
create<arith::MulFOp>(absSqrt, cos, fmf);
893 sinIsZero, zero, b.
create<arith::MulFOp>(absSqrt, sin, fmf));
894 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
895 arith::FastMathFlags::ninf)) {
896 Value inf = cst(APFloat::getInf(floatSemantics));
897 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
898 Value nan = cst(APFloat::getNaN(floatSemantics));
899 Value absImag = b.
create<math::AbsFOp>(elementType, imag, fmf);
902 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
903 Value absImagIsNotInf =
904 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
906 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
908 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
910 resultReal = b.
create<arith::SelectOp>(
911 b.
create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
913 resultReal = b.
create<arith::SelectOp>(
914 b.
create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
916 Value imagSignInf = b.
create<math::CopySignOp>(inf, imag, fmf);
917 resultImag = b.
create<arith::SelectOp>(
918 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
920 resultImag = b.
create<arith::SelectOp>(
921 b.
create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
926 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
927 resultReal = b.
create<arith::SelectOp>(resultIsZero, zero, resultReal);
928 resultImag = b.
create<arith::SelectOp>(resultIsZero, zero, resultImag);
940 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
942 auto type = cast<ComplexType>(adaptor.getComplex().getType());
943 auto elementType = cast<FloatType>(type.getElementType());
945 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
947 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
948 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
952 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
954 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
955 Value isZero = b.
create<arith::AndIOp>(realIsZero, imagIsZero);
956 auto abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
959 Value sign = b.
create<complex::CreateOp>(type, realSign, imagSign);
961 adaptor.getComplex(), sign);
966 template <
typename Op>
971 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
975 auto type = cast<ComplexType>(adaptor.getComplex().getType());
976 auto elementType = cast<FloatType>(type.getElementType());
977 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
978 const auto &floatSemantics = elementType.getFloatSemantics();
981 b.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
983 b.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
987 if constexpr (std::is_same_v<Op, complex::TanOp>) {
989 std::swap(real, imag);
990 real = b.
create<arith::MulFOp>(real, negOne, fmf);
993 auto cst = [&](APFloat v) {
994 return b.
create<arith::ConstantOp>(elementType,
997 Value inf = cst(APFloat::getInf(floatSemantics));
998 Value four = b.
create<arith::ConstantOp>(elementType,
1000 Value twoReal = b.
create<arith::AddFOp>(real, real, fmf);
1001 Value negTwoReal = b.
create<arith::MulFOp>(negOne, twoReal, fmf);
1003 Value expTwoRealMinusOne = b.
create<math::ExpM1Op>(twoReal, fmf);
1004 Value expNegTwoRealMinusOne = b.
create<math::ExpM1Op>(negTwoReal, fmf);
1006 b.
create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1009 Value cosImagSq = b.
create<arith::MulFOp>(cosImag, cosImag, fmf);
1010 Value twoCosTwoImagPlusOne = b.
create<arith::MulFOp>(cosImagSq, four, fmf);
1014 four, b.
create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1016 Value expSumMinusTwo =
1017 b.
create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1019 b.
create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1021 Value isInf = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1022 expSumMinusTwo, inf, fmf);
1023 Value realLimit = b.
create<math::CopySignOp>(negOne, real, fmf);
1026 isInf, realLimit, b.
create<arith::DivFOp>(realNum, denom, fmf));
1027 Value resultImag = b.
create<arith::DivFOp>(imagNum, denom, fmf);
1029 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1030 arith::FastMathFlags::ninf)) {
1031 Value absReal = b.
create<math::AbsFOp>(real, fmf);
1034 Value nan = cst(APFloat::getNaN(floatSemantics));
1036 Value absRealIsInf =
1037 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1039 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1041 absRealIsInf, b.
create<arith::ConstantIntOp>(
true, 1));
1043 Value imagNumIsNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1044 imagNum, imagNum, fmf);
1045 Value resultRealIsNaN =
1046 b.
create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1048 imagIsZero, b.
create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1050 resultReal = b.
create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1052 b.
create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1055 if constexpr (std::is_same_v<Op, complex::TanOp>) {
1057 std::swap(resultReal, resultImag);
1058 resultImag = b.
create<arith::MulFOp>(resultImag, negOne, fmf);
1071 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1074 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1075 auto elementType = cast<FloatType>(type.getElementType());
1077 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1079 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1080 Value negImag = rewriter.
create<arith::NegFOp>(loc, elementType, imag);
1093 arith::FastMathFlags fmf) {
1094 auto elementType = cast<FloatType>(type.getElementType());
1102 Value negD = builder.
create<arith::NegFOp>(d, fmf);
1103 Value argLhs = builder.
create<math::Atan2Op>(b, a, fmf);
1104 Value negDArgLhs = builder.
create<arith::MulFOp>(negD, argLhs, fmf);
1105 Value expNegDArgLhs = builder.
create<math::ExpOp>(negDArgLhs, fmf);
1107 Value coeff = builder.
create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1109 Value cArgLhs = builder.
create<arith::MulFOp>(c, argLhs, fmf);
1110 Value dLnAbs = builder.
create<arith::MulFOp>(d, lnAbs, fmf);
1111 Value q = builder.
create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1118 APFloat::getInf(elementType.getFloatSemantics())));
1123 Value complexOne = builder.
create<complex::CreateOp>(type, one, zero);
1124 Value complexZero = builder.
create<complex::CreateOp>(type, zero, zero);
1125 Value complexInf = builder.
create<complex::CreateOp>(type, inf, zero);
1132 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
abs, zero, fmf);
1134 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1136 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1138 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1141 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1142 Value coeffCosQ = builder.
create<arith::MulFOp>(coeff, cosQ, fmf);
1143 Value coeffSinQ = builder.
create<arith::MulFOp>(coeff, sinQ, fmf);
1144 Value complexOneOrZero =
1145 builder.
create<arith::SelectOp>(cEqZero, complexOne, complexZero);
1147 builder.
create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1149 builder.
create<arith::AndIOp>(
1150 builder.
create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1151 complexOneOrZero, coeffCosSin);
1157 Value rhsEqZero = builder.
create<arith::AndIOp>(cEqZero, dEqZero);
1159 builder.
create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1164 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1167 builder.
create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1172 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1176 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1178 builder.
create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1182 Value rhsLt0 = builder.create<arith::AndIOp>(
1184 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1185 Value cutoff4 = builder.create<arith::SelectOp>(
1186 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1195 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1198 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1199 auto elementType = cast<FloatType>(type.getElementType());
1201 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1202 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1204 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1205 c, d, op.getFastmath())});
1214 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1217 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1218 auto elementType = cast<FloatType>(type.getElementType());
1220 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1222 auto cst = [&](APFloat v) {
1223 return b.
create<arith::ConstantOp>(elementType,
1226 const auto &floatSemantics = elementType.getFloatSemantics();
1228 Value inf = cst(APFloat::getInf(floatSemantics));
1231 Value nan = cst(APFloat::getNaN(floatSemantics));
1233 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
1234 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
1235 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1236 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
1237 Value rsqrtArg = b.
create<arith::MulFOp>(argArg, negHalf, fmf);
1241 Value resultReal = b.
create<arith::MulFOp>(absRsqrt, cos, fmf);
1242 Value resultImag = b.
create<arith::MulFOp>(absRsqrt, sin, fmf);
1244 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1245 arith::FastMathFlags::ninf)) {
1249 Value realSignedZero = b.
create<math::CopySignOp>(zero, real, fmf);
1250 Value imagSignedZero = b.
create<math::CopySignOp>(zero, imag, fmf);
1251 Value negImagSignedZero =
1252 b.
create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1254 Value absReal = b.
create<math::AbsFOp>(real, fmf);
1255 Value absImag = b.
create<math::AbsFOp>(imag, fmf);
1257 Value absImagIsInf =
1258 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1260 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1262 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1263 Value inIsNanInf = b.
create<arith::AndIOp>(absImagIsInf, realIsNan);
1265 Value resultIsZero = b.
create<arith::OrIOp>(inIsNanInf, realIsInf);
1268 b.
create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1269 resultImag = b.
create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1274 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1276 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1277 Value isZero = b.
create<arith::AndIOp>(isRealZero, isImagZero);
1279 resultReal = b.
create<arith::SelectOp>(isZero, inf, resultReal);
1280 resultImag = b.
create<arith::SelectOp>(isZero, nan, resultImag);
1292 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1295 auto type = op.getType();
1296 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1299 rewriter.
create<complex::ReOp>(loc, type, adaptor.getComplex());
1301 rewriter.
create<complex::ImOp>(loc, type, adaptor.getComplex());
1318 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1319 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1320 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1321 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1334 TanTanhOpConversion<complex::TanOp>,
1335 TanTanhOpConversion<complex::TanhOp>,
1343 struct ConvertComplexToStandardPass
1344 :
public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1345 void runOnOperation()
override;
1348 void ConvertComplexToStandardPass::runOnOperation() {
1354 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1355 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1358 signalPassFailure();
1363 return std::make_unique<ConvertComplexToStandardPass>();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This provides public APIs that all operations should have.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.