19 #include <type_traits>
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
/// Mode selector compared against `fn` by the absolute-value helper in this
/// file. The visible code branches on `fn == AbsFn::rsqrt` (where it applies
/// math::RsqrtOp to the intermediate value) and on `fn == AbsFn::sqrt`, and
/// the complex::SqrtOp lowering calls computeAbs(..., AbsFn::sqrt).
/// NOTE(review): `abs` is presumably the default mode when no selector is
/// passed -- confirm against the computeAbs signature, which is not part of
/// this excerpt.
30 enum class AbsFn {
abs, sqrt, rsqrt };
45 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
48 Value ratioSq = b.
create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49 Value ratioSqPlusOne = b.
create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
52 if (fn == AbsFn::rsqrt) {
53 ratioSqPlusOne = b.
create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
58 if (fn == AbsFn::sqrt) {
63 Value p025 = b.
create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64 result = b.
create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
66 Value sqrt = b.
create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67 result = b.
create<arith::MulFOp>(
max, sqrt, fmfWithNaNInf);
70 Value isNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71 result, fmfWithNaNInf);
72 return b.
create<arith::SelectOp>(isNaN,
min, result);
79 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
83 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
86 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
87 rewriter.
replaceOp(op, computeAbs(real, imag, fmf, b));
98 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
102 auto type = cast<ComplexType>(op.getType());
103 Type elementType = type.getElementType();
104 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106 Value lhs = adaptor.getLhs();
107 Value rhs = adaptor.getRhs();
109 Value rhsSquared = b.
create<complex::MulOp>(type, rhs, rhs, fmf);
110 Value lhsSquared = b.
create<complex::MulOp>(type, lhs, lhs, fmf);
111 Value rhsSquaredPlusLhsSquared =
112 b.
create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113 Value sqrtOfRhsSquaredPlusLhsSquared =
114 b.
create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
118 Value one = b.
create<arith::ConstantOp>(elementType,
120 Value i = b.
create<complex::CreateOp>(type, zero, one);
121 Value iTimesLhs = b.
create<complex::MulOp>(i, lhs, fmf);
122 Value rhsPlusILhs = b.
create<complex::AddOp>(rhs, iTimesLhs, fmf);
125 rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126 Value logResult = b.
create<complex::LogOp>(divResult, fmf);
130 Value negativeI = b.
create<complex::CreateOp>(type, zero, negativeOne);
137 template <
typename ComparisonOp, arith::CmpFPredicate p>
// Compile-time choice of the integer op type: resolves to arith::AndIOp when
// ComparisonOp is complex::EqualOp and to arith::OrIOp for every other
// ComparisonOp.
// NOTE(review): presumably used to merge the real-part and imaginary-part
// arith::CmpFOp results computed in matchAndRewrite -- the use site is not
// part of this excerpt, confirm there.
140 using ResultCombiner =
141 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142 arith::AndIOp, arith::OrIOp>;
145 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
147 auto loc = op.getLoc();
148 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 Value realLhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getLhs());
151 Value imagLhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getLhs());
152 Value realRhs = rewriter.
create<complex::ReOp>(loc, type, adaptor.getRhs());
153 Value imagRhs = rewriter.
create<complex::ImOp>(loc, type, adaptor.getRhs());
154 Value realComparison =
155 rewriter.
create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156 Value imagComparison =
157 rewriter.
create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
168 template <
typename BinaryComplexOp,
typename BinaryStandardOp>
173 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
175 auto type = cast<ComplexType>(adaptor.getLhs().getType());
176 auto elementType = cast<FloatType>(type.getElementType());
178 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180 Value realLhs = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
181 Value realRhs = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
182 Value resultReal = b.
create<BinaryStandardOp>(elementType, realLhs, realRhs,
184 Value imagLhs = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
185 Value imagRhs = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
186 Value resultImag = b.
create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
194 template <
typename TrigonometricOp>
201 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
203 auto loc = op.getLoc();
204 auto type = cast<ComplexType>(adaptor.getComplex().getType());
205 auto elementType = cast<FloatType>(type.getElementType());
206 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
209 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
211 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
217 loc, elementType, rewriter.
getFloatAttr(elementType, 0.5));
218 Value exp = rewriter.
create<math::ExpOp>(loc, imag, fmf);
219 Value scaledExp = rewriter.
create<arith::MulFOp>(loc, half, exp, fmf);
220 Value reciprocalExp = rewriter.
create<arith::DivFOp>(loc, half, exp, fmf);
221 Value sin = rewriter.
create<math::SinOp>(loc, real, fmf);
222 Value cos = rewriter.
create<math::CosOp>(loc, real, fmf);
225 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
232 virtual std::pair<Value, Value>
235 arith::FastMathFlagsAttr fmf)
const = 0;
238 struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
239 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
244 arith::FastMathFlagsAttr fmf)
const override {
255 rewriter.
create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, cos, fmf);
258 rewriter.
create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, sin, fmf);
260 return {resultReal, resultImag};
268 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
270 auto loc = op.getLoc();
271 auto type = cast<ComplexType>(adaptor.getLhs().getType());
272 auto elementType = cast<FloatType>(type.getElementType());
273 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
276 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getLhs());
278 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getLhs());
280 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getRhs());
282 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getRhs());
306 Value rhsRealImagRatio =
307 rewriter.
create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308 Value rhsRealImagDenom = rewriter.
create<arith::AddFOp>(
310 rewriter.
create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
312 Value realNumerator1 = rewriter.
create<arith::AddFOp>(
314 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
316 Value resultReal1 = rewriter.
create<arith::DivFOp>(loc, realNumerator1,
317 rhsRealImagDenom, fmf);
318 Value imagNumerator1 = rewriter.
create<arith::SubFOp>(
320 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
322 Value resultImag1 = rewriter.
create<arith::DivFOp>(loc, imagNumerator1,
323 rhsRealImagDenom, fmf);
325 Value rhsImagRealRatio =
326 rewriter.
create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327 Value rhsImagRealDenom = rewriter.
create<arith::AddFOp>(
329 rewriter.
create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
331 Value realNumerator2 = rewriter.
create<arith::AddFOp>(
333 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
335 Value resultReal2 = rewriter.
create<arith::DivFOp>(loc, realNumerator2,
336 rhsImagRealDenom, fmf);
337 Value imagNumerator2 = rewriter.
create<arith::SubFOp>(
339 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
341 Value resultImag2 = rewriter.
create<arith::DivFOp>(loc, imagNumerator2,
342 rhsImagRealDenom, fmf);
347 loc, elementType, rewriter.
getZeroAttr(elementType));
348 Value rhsRealAbs = rewriter.
create<math::AbsFOp>(loc, rhsReal, fmf);
349 Value rhsRealIsZero = rewriter.
create<arith::CmpFOp>(
350 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351 Value rhsImagAbs = rewriter.
create<math::AbsFOp>(loc, rhsImag, fmf);
352 Value rhsImagIsZero = rewriter.
create<arith::CmpFOp>(
353 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354 Value lhsRealIsNotNaN = rewriter.
create<arith::CmpFOp>(
355 loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356 Value lhsImagIsNotNaN = rewriter.
create<arith::CmpFOp>(
357 loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358 Value lhsContainsNotNaNValue =
359 rewriter.
create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360 Value resultIsInfinity = rewriter.
create<arith::AndIOp>(
361 loc, lhsContainsNotNaNValue,
362 rewriter.
create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
366 elementType, APFloat::getInf(elementType.getFloatSemantics())));
367 Value infWithSignOfRhsReal =
368 rewriter.
create<math::CopySignOp>(loc, inf, rhsReal);
369 Value infinityResultReal =
370 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371 Value infinityResultImag =
372 rewriter.
create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
375 Value rhsRealFinite = rewriter.
create<arith::CmpFOp>(
376 loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377 Value rhsImagFinite = rewriter.
create<arith::CmpFOp>(
378 loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
380 rewriter.
create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381 Value lhsRealAbs = rewriter.
create<math::AbsFOp>(loc, lhsReal, fmf);
382 Value lhsRealInfinite = rewriter.
create<arith::CmpFOp>(
383 loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384 Value lhsImagAbs = rewriter.
create<math::AbsFOp>(loc, lhsImag, fmf);
385 Value lhsImagInfinite = rewriter.
create<arith::CmpFOp>(
386 loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
388 rewriter.
create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389 Value infNumFiniteDenom =
390 rewriter.
create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
392 loc, elementType, rewriter.
getFloatAttr(elementType, 1));
393 Value lhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
394 loc, rewriter.
create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
396 Value lhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
397 loc, rewriter.
create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
399 Value lhsRealIsInfWithSignTimesRhsReal =
400 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401 Value lhsImagIsInfWithSignTimesRhsImag =
402 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403 Value resultReal3 = rewriter.
create<arith::MulFOp>(
405 rewriter.
create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406 lhsImagIsInfWithSignTimesRhsImag, fmf),
408 Value lhsRealIsInfWithSignTimesRhsImag =
409 rewriter.
create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410 Value lhsImagIsInfWithSignTimesRhsReal =
411 rewriter.
create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412 Value resultImag3 = rewriter.
create<arith::MulFOp>(
414 rewriter.
create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415 lhsRealIsInfWithSignTimesRhsImag, fmf),
419 Value lhsRealFinite = rewriter.
create<arith::CmpFOp>(
420 loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421 Value lhsImagFinite = rewriter.
create<arith::CmpFOp>(
422 loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
424 rewriter.
create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425 Value rhsRealInfinite = rewriter.
create<arith::CmpFOp>(
426 loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427 Value rhsImagInfinite = rewriter.
create<arith::CmpFOp>(
428 loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
430 rewriter.
create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431 Value finiteNumInfiniteDenom =
432 rewriter.
create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433 Value rhsRealIsInfWithSign = rewriter.
create<math::CopySignOp>(
434 loc, rewriter.
create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
436 Value rhsImagIsInfWithSign = rewriter.
create<math::CopySignOp>(
437 loc, rewriter.
create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
439 Value rhsRealIsInfWithSignTimesLhsReal =
440 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441 Value rhsImagIsInfWithSignTimesLhsImag =
442 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443 Value resultReal4 = rewriter.
create<arith::MulFOp>(
445 rewriter.
create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446 rhsImagIsInfWithSignTimesLhsImag, fmf),
448 Value rhsRealIsInfWithSignTimesLhsImag =
449 rewriter.
create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450 Value rhsImagIsInfWithSignTimesLhsReal =
451 rewriter.
create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452 Value resultImag4 = rewriter.
create<arith::MulFOp>(
454 rewriter.
create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455 rhsImagIsInfWithSignTimesLhsReal, fmf),
458 Value realAbsSmallerThanImagAbs = rewriter.
create<arith::CmpFOp>(
459 loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460 Value resultReal = rewriter.
create<arith::SelectOp>(
461 loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462 Value resultImag = rewriter.
create<arith::SelectOp>(
463 loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464 Value resultRealSpecialCase3 = rewriter.
create<arith::SelectOp>(
465 loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466 Value resultImagSpecialCase3 = rewriter.
create<arith::SelectOp>(
467 loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468 Value resultRealSpecialCase2 = rewriter.
create<arith::SelectOp>(
469 loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470 Value resultImagSpecialCase2 = rewriter.
create<arith::SelectOp>(
471 loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472 Value resultRealSpecialCase1 = rewriter.
create<arith::SelectOp>(
473 loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474 Value resultImagSpecialCase1 = rewriter.
create<arith::SelectOp>(
475 loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
477 Value resultRealIsNaN = rewriter.
create<arith::CmpFOp>(
478 loc, arith::CmpFPredicate::UNO, resultReal, zero);
479 Value resultImagIsNaN = rewriter.
create<arith::CmpFOp>(
480 loc, arith::CmpFPredicate::UNO, resultImag, zero);
482 rewriter.
create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483 Value resultRealWithSpecialCases = rewriter.
create<arith::SelectOp>(
484 loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485 Value resultImagWithSpecialCases = rewriter.
create<arith::SelectOp>(
486 loc, resultIsNaN, resultImagSpecialCase1, resultImag);
489 op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
498 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
500 auto loc = op.getLoc();
501 auto type = cast<ComplexType>(adaptor.getComplex().getType());
502 auto elementType = cast<FloatType>(type.getElementType());
503 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
506 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
508 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509 Value expReal = rewriter.
create<math::ExpOp>(loc, real, fmf.getValue());
510 Value cosImag = rewriter.
create<math::CosOp>(loc, imag, fmf.getValue());
512 rewriter.
create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513 Value sinImag = rewriter.
create<math::SinOp>(loc, imag, fmf.getValue());
515 rewriter.
create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
525 arith::FastMathFlagsAttr fmf) {
526 auto argType = mlir::cast<FloatType>(arg.
getType());
529 for (
unsigned i = 1; i < coefficients.size(); ++i) {
530 poly = b.
create<math::FmaOp>(
546 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
548 auto type = op.getType();
549 auto elemType = mlir::cast<FloatType>(type.getElementType());
551 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
553 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
554 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
559 Value expm1Real = b.
create<math::ExpM1Op>(real, fmf);
560 Value expReal = b.
create<arith::AddFOp>(expm1Real, one, fmf);
563 Value cosm1Imag = emitCosm1(imag, fmf, b);
564 Value cosImag = b.
create<arith::AddFOp>(cosm1Imag, one, fmf);
567 b.
create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
569 Value imagIsZero = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570 zero, fmf.getValue());
572 imagIsZero, zero, b.
create<arith::MulFOp>(expReal, sinImag, fmf));
580 Value emitCosm1(
Value arg, arith::FastMathFlagsAttr fmf,
582 auto argType = mlir::cast<FloatType>(arg.
getType());
588 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591 4.1666666666666666609054E-2,
594 Value forLargeArg = b.
create<arith::AddFOp>(cos, negOne, fmf);
596 Value argPow2 = b.
create<arith::MulFOp>(arg, arg, fmf);
597 Value argPow4 = b.
create<arith::MulFOp>(argPow2, argPow2, fmf);
598 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
601 b.
create<arith::AddFOp>(b.
create<arith::MulFOp>(argPow4, poly, fmf),
602 b.
create<arith::MulFOp>(negHalf, argPow2, fmf));
607 Value cond = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608 piOver4Pow2, fmf.getValue());
609 return b.
create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
617 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
619 auto type = cast<ComplexType>(adaptor.getComplex().getType());
620 auto elementType = cast<FloatType>(type.getElementType());
621 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
624 Value abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(),
626 Value resultReal = b.
create<math::LogOp>(elementType,
abs, fmf.getValue());
627 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
628 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
630 b.
create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
641 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
643 auto type = cast<ComplexType>(adaptor.getComplex().getType());
644 auto elementType = cast<FloatType>(type.getElementType());
645 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
648 Value real = b.
create<complex::ReOp>(adaptor.getComplex());
649 Value imag = b.
create<complex::ImOp>(adaptor.getComplex());
651 Value half = b.
create<arith::ConstantOp>(elementType,
653 Value one = b.
create<arith::ConstantOp>(elementType,
655 Value realPlusOne = b.
create<arith::AddFOp>(real, one, fmf);
656 Value absRealPlusOne = b.
create<math::AbsFOp>(realPlusOne, fmf);
659 Value maxAbs = b.
create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
660 Value minAbs = b.
create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
662 Value useReal = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
663 realPlusOne, absImag, fmf);
664 Value maxMinusOne = b.
create<arith::SubFOp>(maxAbs, one, fmf);
665 Value maxAbsOfRealPlusOneAndImagMinusOne =
666 b.
create<arith::SelectOp>(useReal, real, maxMinusOne);
667 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
668 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
669 Value minMaxRatio = b.
create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
670 Value logOfMaxAbsOfRealPlusOneAndImag =
671 b.
create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
673 b.
create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
676 b.
create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
677 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
679 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
681 Value resultImag = b.
create<math::Atan2Op>(imag, realPlusOne, fmf);
692 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
695 auto type = cast<ComplexType>(adaptor.getLhs().getType());
696 auto elementType = cast<FloatType>(type.getElementType());
697 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698 auto fmfValue = fmf.getValue();
700 Value lhsReal = b.
create<complex::ReOp>(elementType, adaptor.getLhs());
701 Value lhsRealAbs = b.
create<math::AbsFOp>(lhsReal, fmfValue);
702 Value lhsImag = b.
create<complex::ImOp>(elementType, adaptor.getLhs());
703 Value lhsImagAbs = b.
create<math::AbsFOp>(lhsImag, fmfValue);
704 Value rhsReal = b.
create<complex::ReOp>(elementType, adaptor.getRhs());
705 Value rhsRealAbs = b.
create<math::AbsFOp>(rhsReal, fmfValue);
706 Value rhsImag = b.
create<complex::ImOp>(elementType, adaptor.getRhs());
707 Value rhsImagAbs = b.
create<math::AbsFOp>(rhsImag, fmfValue);
709 Value lhsRealTimesRhsReal =
710 b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
711 Value lhsRealTimesRhsRealAbs =
712 b.
create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
713 Value lhsImagTimesRhsImag =
714 b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
715 Value lhsImagTimesRhsImagAbs =
716 b.
create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
717 Value real = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
718 lhsImagTimesRhsImag, fmfValue);
720 Value lhsImagTimesRhsReal =
721 b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
722 Value lhsImagTimesRhsRealAbs =
723 b.
create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
724 Value lhsRealTimesRhsImag =
725 b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
726 Value lhsRealTimesRhsImagAbs =
727 b.
create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
728 Value imag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
729 lhsRealTimesRhsImag, fmfValue);
733 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
735 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
736 Value isNan = b.
create<arith::AndIOp>(realIsNan, imagIsNan);
741 APFloat::getInf(elementType.getFloatSemantics())));
745 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
747 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
748 Value lhsIsInf = b.
create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
750 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
752 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
755 Value one = b.
create<arith::ConstantOp>(elementType,
757 Value lhsRealIsInfFloat =
758 b.
create<arith::SelectOp>(lhsRealIsInf, one, zero);
759 lhsReal = b.
create<arith::SelectOp>(
760 lhsIsInf, b.
create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
762 Value lhsImagIsInfFloat =
763 b.
create<arith::SelectOp>(lhsImagIsInf, one, zero);
764 lhsImag = b.
create<arith::SelectOp>(
765 lhsIsInf, b.
create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
767 Value lhsIsInfAndRhsRealIsNan =
768 b.
create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
769 rhsReal = b.
create<arith::SelectOp>(
770 lhsIsInfAndRhsRealIsNan, b.
create<math::CopySignOp>(zero, rhsReal),
772 Value lhsIsInfAndRhsImagIsNan =
773 b.
create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
774 rhsImag = b.
create<arith::SelectOp>(
775 lhsIsInfAndRhsImagIsNan, b.
create<math::CopySignOp>(zero, rhsImag),
780 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
782 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
783 Value rhsIsInf = b.
create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
785 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
787 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
788 Value rhsRealIsInfFloat =
789 b.
create<arith::SelectOp>(rhsRealIsInf, one, zero);
790 rhsReal = b.
create<arith::SelectOp>(
791 rhsIsInf, b.
create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
793 Value rhsImagIsInfFloat =
794 b.
create<arith::SelectOp>(rhsImagIsInf, one, zero);
795 rhsImag = b.
create<arith::SelectOp>(
796 rhsIsInf, b.
create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
798 Value rhsIsInfAndLhsRealIsNan =
799 b.
create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
800 lhsReal = b.
create<arith::SelectOp>(
801 rhsIsInfAndLhsRealIsNan, b.
create<math::CopySignOp>(zero, lhsReal),
803 Value rhsIsInfAndLhsImagIsNan =
804 b.
create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
805 lhsImag = b.
create<arith::SelectOp>(
806 rhsIsInfAndLhsImagIsNan, b.
create<math::CopySignOp>(zero, lhsImag),
808 Value recalc = b.
create<arith::OrIOp>(lhsIsInf, rhsIsInf);
812 Value lhsRealTimesRhsRealIsInf = b.
create<arith::CmpFOp>(
813 arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
814 Value lhsImagTimesRhsImagIsInf = b.
create<arith::CmpFOp>(
815 arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
816 Value isSpecialCase = b.
create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
817 lhsImagTimesRhsImagIsInf);
818 Value lhsRealTimesRhsImagIsInf = b.
create<arith::CmpFOp>(
819 arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
821 b.
create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
822 Value lhsImagTimesRhsRealIsInf = b.
create<arith::CmpFOp>(
823 arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
825 b.
create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
830 isSpecialCase = b.
create<arith::AndIOp>(isSpecialCase, notRecalc);
831 Value isSpecialCaseAndLhsRealIsNan =
832 b.
create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
833 lhsReal = b.
create<arith::SelectOp>(
834 isSpecialCaseAndLhsRealIsNan, b.
create<math::CopySignOp>(zero, lhsReal),
836 Value isSpecialCaseAndLhsImagIsNan =
837 b.
create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
838 lhsImag = b.
create<arith::SelectOp>(
839 isSpecialCaseAndLhsImagIsNan, b.
create<math::CopySignOp>(zero, lhsImag),
841 Value isSpecialCaseAndRhsRealIsNan =
842 b.
create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
843 rhsReal = b.
create<arith::SelectOp>(
844 isSpecialCaseAndRhsRealIsNan, b.
create<math::CopySignOp>(zero, rhsReal),
846 Value isSpecialCaseAndRhsImagIsNan =
847 b.
create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
848 rhsImag = b.
create<arith::SelectOp>(
849 isSpecialCaseAndRhsImagIsNan, b.
create<math::CopySignOp>(zero, rhsImag),
851 recalc = b.
create<arith::OrIOp>(recalc, isSpecialCase);
852 recalc = b.
create<arith::AndIOp>(isNan, recalc);
855 lhsRealTimesRhsReal = b.
create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
856 lhsImagTimesRhsImag = b.
create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
857 Value newReal = b.
create<arith::SubFOp>(lhsRealTimesRhsReal,
858 lhsImagTimesRhsImag, fmfValue);
859 real = b.
create<arith::SelectOp>(
860 recalc, b.
create<arith::MulFOp>(inf, newReal, fmfValue), real);
863 lhsImagTimesRhsReal = b.
create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
864 lhsRealTimesRhsImag = b.
create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
865 Value newImag = b.
create<arith::AddFOp>(lhsImagTimesRhsReal,
866 lhsRealTimesRhsImag, fmfValue);
867 imag = b.
create<arith::SelectOp>(
868 recalc, b.
create<arith::MulFOp>(inf, newImag, fmfValue), imag);
879 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
881 auto loc = op.getLoc();
882 auto type = cast<ComplexType>(adaptor.getComplex().getType());
883 auto elementType = cast<FloatType>(type.getElementType());
886 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
888 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
889 Value negReal = rewriter.
create<arith::NegFOp>(loc, real);
890 Value negImag = rewriter.
create<arith::NegFOp>(loc, imag);
// complex.sin specialization of TrigonometricOpConversion. It reuses the base
// class constructors and overrides a const hook that returns a
// {real, imaginary} pair built from values prepared by the base class.
// NOTE(review): the hook's name and its leading parameters are on lines that
// are missing from this listing.
896 struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
897 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
902 arith::FastMathFlagsAttr fmf)
const override {
// scaledExp + reciprocalExp; presumably bound to `sum` on a line not shown.
913 rewriter.
create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
// Real part of the result: sum * sin.
914 Value resultReal = rewriter.
create<arith::MulFOp>(loc, sum, sin, fmf);
// scaledExp - reciprocalExp; presumably bound to `diff` on a line not shown.
916 rewriter.
create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
// Imaginary part of the result: diff * cos.
917 Value resultImag = rewriter.
create<arith::MulFOp>(loc, diff, cos, fmf);
918 return {resultReal, resultImag};
// Rewrite for complex.sqrt. Visible outline: split the operand into real and
// imaginary parts, obtain a magnitude term from computeAbs(..., AbsFn::sqrt)
// and an angle term from atan2(imag, real) scaled by `half`, multiply the
// magnitude by `cos` / `sin`, then patch the results for infinities, NaN and
// zero.
// NOTE(review): several original lines are missing from this listing, so
// `b`, `zero`, `sin`, `cos`, `resultImag` and some select operands are
// defined out of view.
927 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
931 auto type = cast<ComplexType>(op.getType());
932 auto elementType = cast<FloatType>(type.getElementType());
933 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
// Helper that materializes an APFloat as an arith.constant of the element
// type (the attribute argument is on a line not shown).
935 auto cst = [&](APFloat v) {
936 return b.
create<arith::ConstantOp>(elementType,
939 const auto &floatSemantics = elementType.getFloatSemantics();
// Constant named `half`; its value is on a line not shown (presumably 0.5).
941 Value half = b.
create<arith::ConstantOp>(elementType,
944 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
945 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
// Magnitude term, requested in the AbsFn::sqrt flavor of computeAbs.
946 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
// Angle term: atan2(imag, real) multiplied by `half`.
947 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
948 Value sqrtArg = b.
create<arith::MulFOp>(argArg, half, fmf);
// Ordered-equal test of `sin` against zero; the result is presumably the
// `sinIsZero` value used below.
954 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
956 Value resultReal = b.
create<arith::MulFOp>(absSqrt, cos, fmf);
// Operands of a call whose head is not shown: pick `zero` when sinIsZero,
// otherwise absSqrt * sin (presumably the select defining resultImag).
958 sinIsZero, zero, b.
create<arith::MulFOp>(absSqrt, sin, fmf));
// Special-case handling is emitted unless BOTH nnan and ninf are set.
959 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
960 arith::FastMathFlags::ninf)) {
961 Value inf = cst(APFloat::getInf(floatSemantics));
962 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
963 Value nan = cst(APFloat::getNaN(floatSemantics));
964 Value absImag = b.
create<math::AbsFOp>(elementType, imag, fmf);
// |imag| == inf; presumably bound to `absImagIsInf`.
967 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
968 Value absImagIsNotInf =
969 b.
create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
// real == +inf; presumably bound to `realIsInf`.
971 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
// real == -inf; presumably bound to `realIsNegInf`.
973 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
// Real part: select `zero` when real is -inf and |imag| is not inf (the
// false operand is on a line not shown).
975 resultReal = b.
create<arith::SelectOp>(
976 b.
create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
// Real part: select `inf` when |imag| is inf or real is +inf.
978 resultReal = b.
create<arith::SelectOp>(
979 b.
create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
// Infinity carrying the sign of `imag`.
981 Value imagSignInf = b.
create<math::CopySignOp>(inf, imag, fmf);
// Imaginary part: select keyed on absSqrt being NaN (unordered with itself);
// the selected operands are on a line not shown.
982 resultImag = b.
create<arith::SelectOp>(
983 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
// Imaginary part: select the signed infinity when |imag| is inf or real is
// -inf (the false operand is on a line not shown).
985 resultImag = b.
create<arith::SelectOp>(
986 b.
create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
// absSqrt == 0; presumably bound to `resultIsZero`. When it holds, both
// result components are replaced by `zero`.
991 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
992 resultReal = b.
create<arith::SelectOp>(resultIsZero, zero, resultReal);
993 resultImag = b.
create<arith::SelectOp>(resultIsZero, zero, resultImag);
// Rewrite for complex.sign. Visible outline: split the operand, detect the
// case where both parts compare equal to zero, take complex.abs of the
// operand, and build a complex value from `realSign` / `imagSign`.
// NOTE(review): `b`, `zero`, `realSign` and `imagSign` are defined on lines
// missing from this listing.
1005 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
1007 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008 auto elementType = cast<FloatType>(type.getElementType());
1010 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1012 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
1013 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
// Ordered-equal comparisons of each part with zero (created without
// fast-math flags); presumably bound to `realIsZero` / `imagIsZero`.
1017 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
1019 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
1020 Value isZero = b.
create<arith::AndIOp>(realIsZero, imagIsZero);
// Magnitude of the operand; complex.abs is itself lowered by another pattern.
1021 auto abs = b.
create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
1024 Value sign = b.
create<complex::CreateOp>(type, realSign, imagSign);
// Trailing operands of a call whose head is not shown; presumably a select
// on `isZero` that keeps the original operand for a zero input — TODO confirm.
1026 adaptor.getComplex(), sign);
// Shared rewrite for complex.tan and complex.tanh, templated on the op type.
// For complex::TanOp the real and imaginary inputs are swapped and `real` is
// multiplied by `negOne` before the common computation; after it, the two
// results are swapped back and `resultImag` is multiplied by `negOne`.
// The common computation builds numerators and a denominator from
// expm1(2*real), expm1(-2*real), cos(imag) and sin(imag), divides, and then
// patches NaN / zero cases unless both nnan and ninf are set.
// NOTE(review): several original lines are missing from this listing, so
// `b`, `loc`, `real`, `imag`, `zero`, `negOne`, `cosImag`, `sinImag`,
// `realNum`, `imagNum`, `denom`, `resultReal`, `absRealIsNotInf` and
// `resultImagIsZero` are defined out of view.
1031 template <
typename Op>
1036 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
1040 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1041 auto elementType = cast<FloatType>(type.getElementType());
1042 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1043 const auto &floatSemantics = elementType.getFloatSemantics();
// Component extraction; presumably bound to `real` / `imag`.
1046 b.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1048 b.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
// Tan-only input transformation (compile-time branch).
1052 if constexpr (std::is_same_v<Op, complex::TanOp>) {
1054 std::swap(real, imag);
1055 real = b.
create<arith::MulFOp>(real, negOne, fmf);
// Helper that materializes an APFloat as an arith.constant of the element
// type (the attribute argument is on a line not shown).
1058 auto cst = [&](APFloat v) {
1059 return b.
create<arith::ConstantOp>(elementType,
1062 Value inf = cst(APFloat::getInf(floatSemantics));
// Constant named `four`; its value is on a line not shown (presumably 4.0).
1063 Value four = b.
create<arith::ConstantOp>(elementType,
// 2*real is formed as real + real; its negation as negOne * twoReal.
1065 Value twoReal = b.
create<arith::AddFOp>(real, real, fmf);
1066 Value negTwoReal = b.
create<arith::MulFOp>(negOne, twoReal, fmf);
// math.expm1 of +2*real and of -2*real.
1068 Value expTwoRealMinusOne = b.
create<math::ExpM1Op>(twoReal, fmf);
1069 Value expNegTwoRealMinusOne = b.
create<math::ExpM1Op>(negTwoReal, fmf);
// Difference of the two expm1 terms; presumably bound to `realNum`.
1071 b.
create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1074 Value cosImagSq = b.
create<arith::MulFOp>(cosImag, cosImag, fmf);
// Computed as four * cos(imag)^2.
1075 Value twoCosTwoImagPlusOne = b.
create<arith::MulFOp>(cosImagSq, four, fmf);
// Operands of a multiply whose head is not shown: four * (cosImag * sinImag),
// presumably bound to `imagNum`.
1079 four, b.
create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1081 Value expSumMinusTwo =
1082 b.
create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
// expSumMinusTwo + twoCosTwoImagPlusOne; presumably bound to `denom`.
1084 b.
create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
// When the exponential sum equals +inf, the real result uses `realLimit`
// (magnitude of negOne with the sign of `real`) instead of the quotient.
1086 Value isInf = b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1087 expSumMinusTwo, inf, fmf);
1088 Value realLimit = b.
create<math::CopySignOp>(negOne, real, fmf);
// Operands of a call whose head is not shown (presumably the select that
// defines resultReal).
1091 isInf, realLimit, b.
create<arith::DivFOp>(realNum, denom, fmf));
1092 Value resultImag = b.
create<arith::DivFOp>(imagNum, denom, fmf);
// Special-case handling is emitted unless BOTH nnan and ninf are set.
1094 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1095 arith::FastMathFlags::ninf)) {
1096 Value absReal = b.
create<math::AbsFOp>(real, fmf);
1099 Value nan = cst(APFloat::getNaN(floatSemantics));
1101 Value absRealIsInf =
1102 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
// imag == 0; presumably bound to `imagIsZero`.
1104 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
// Operands of a call whose head is not shown: absRealIsInf combined with an
// integer constant built from (true, 1); presumably an XOR that yields
// `absRealIsNotInf` — TODO confirm.
1106 absRealIsInf, b.
create<arith::ConstantIntOp>(
true, 1));
// imagNum is NaN when it compares unordered with itself.
1108 Value imagNumIsNaN = b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1109 imagNum, imagNum, fmf);
1110 Value resultRealIsNaN =
1111 b.
create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
// Operands of a call whose head is not shown: imagIsZero combined with
// (absRealIsInf AND imagNumIsNaN); presumably defines `resultImagIsZero`.
1113 imagIsZero, b.
create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1115 resultReal = b.
create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
// Select `zero` for the imaginary result when resultImagIsZero (the
// assignment target is on a line not shown).
1117 b.
create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
// Tan-only output transformation, mirroring the input transformation above.
1120 if constexpr (std::is_same_v<Op, complex::TanOp>) {
1122 std::swap(resultReal, resultImag);
1123 resultImag = b.
create<arith::MulFOp>(resultImag, negOne, fmf);
// Rewrite for complex.conj. The visible statements read both parts of the
// operand and negate only the imaginary part with arith.negf.
// NOTE(review): the final replacement of `op` is on lines missing from this
// listing.
1136 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1138 auto loc = op.getLoc();
1139 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1140 auto elementType = cast<FloatType>(type.getElementType());
// Component extraction; presumably bound to `real` / `imag` on lines not
// shown.
1142 rewriter.
create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1144 rewriter.
create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1145 Value negImag = rewriter.
create<arith::NegFOp>(loc, elementType, imag);
// Body of a helper whose signature head is missing from this listing;
// presumably powOpConversionImpl, which the complex.pow pattern calls with
// (builder, type, lhs, c, d, fastmath) — TODO confirm.
// Visible outline: form coeff = absToC * exp(-d * atan2(b, a)) and
// q = c * atan2(b, a) + d * lnAbs, build the complex value
// (coeff * cosQ, coeff * sinQ), then run it through a chain of selects
// (cutoff0 .. cutoff4) that substitute complexOne / complexZero / complexInf
// for particular operand combinations.
// NOTE(review): `a`, `b`, `abs`, `absToC`, `lnAbs`, `cosQ`, `sinQ`, `zero`,
// `one` and `inf` are defined on lines not shown; `a` / `b` are presumably
// the real / imaginary parts of the base.
1158 arith::FastMathFlags fmf) {
1159 auto elementType = cast<FloatType>(type.getElementType());
// Magnitude coefficient: absToC * exp(-d * atan2(b, a)).
1167 Value negD = builder.
create<arith::NegFOp>(d, fmf);
1168 Value argLhs = builder.
create<math::Atan2Op>(b, a, fmf);
1169 Value negDArgLhs = builder.
create<arith::MulFOp>(negD, argLhs, fmf);
1170 Value expNegDArgLhs = builder.
create<math::ExpOp>(negDArgLhs, fmf);
1172 Value coeff = builder.
create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
// Angle: q = c * argLhs + d * lnAbs.
1174 Value cArgLhs = builder.
create<arith::MulFOp>(c, argLhs, fmf);
1175 Value dLnAbs = builder.
create<arith::MulFOp>(d, lnAbs, fmf);
1176 Value q = builder.
create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
// Tail of the `inf` constant definition (its head is on a line not shown).
1183 APFloat::getInf(elementType.getFloatSemantics())));
// Complex constants used by the cutoff selects below.
1188 Value complexOne = builder.
create<complex::CreateOp>(type, one, zero);
1189 Value complexZero = builder.
create<complex::CreateOp>(type, zero, zero);
1190 Value complexInf = builder.
create<complex::CreateOp>(type, inf, zero);
// Ordered comparisons against zero; presumably bound to `absEqZero`,
// `dEqZero`, `cEqZero`, a `b == 0` flag, and `zeroLeC` respectively.
1197 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
abs, zero, fmf);
1199 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1201 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1203 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1206 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1207 Value coeffCosQ = builder.
create<arith::MulFOp>(coeff, cosQ, fmf);
1208 Value coeffSinQ = builder.
create<arith::MulFOp>(coeff, sinQ, fmf);
// One when c == 0, otherwise zero.
1209 Value complexOneOrZero =
1210 builder.
create<arith::SelectOp>(cEqZero, complexOne, complexZero);
// General-case result; presumably bound to `coeffCosSin`.
1212 builder.
create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
// Operands of a select whose head is not shown (presumably `cutoff0`):
// use complexOneOrZero when abs == 0, d == 0 and 0 <= c.
1214 builder.
create<arith::AndIOp>(
1215 builder.
create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1216 complexOneOrZero, coeffCosSin);
// Exponent equal to zero (c == 0 and d == 0) selects complexOne.
1222 Value rhsEqZero = builder.
create<arith::AndIOp>(cEqZero, dEqZero);
1224 builder.
create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
// `a == one` is one operand of the `lhsEqOne` condition (the rest is on
// lines not shown); when it holds the result is complexOne.
1229 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1232 builder.
create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
// `a == inf` is one operand of the `lhsEqInf` condition; `c > 0` is one
// operand of `rhsGt0`. Both together select complexInf.
1237 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1241 builder.
create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1243 builder.
create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
// `c < 0` is one operand of `rhsLt0`; together with lhsEqInf it selects
// complexZero.
1247 Value rhsLt0 = builder.create<arith::AndIOp>(
1249 builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1250 Value cutoff4 = builder.create<arith::SelectOp>(
1251 builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
// Rewrite for complex.pow. Splits the exponent (rhs) into its real part `c`
// and imaginary part `d`, then replaces the op with the value produced by
// powOpConversionImpl for the base (lhs), `c`, `d` and the op's fast-math
// flags.
// NOTE(review): the construction of `builder` is on a line missing from this
// listing.
1260 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1263 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1264 auto elementType = cast<FloatType>(type.getElementType());
1266 Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1267 Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1269 rewriter.
replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1270 c, d, op.getFastmath())});
// Rewrite for complex.rsqrt. Visible outline: split the operand, obtain a
// magnitude term from computeAbs(..., AbsFn::rsqrt) and an angle term from
// atan2(imag, real) scaled by `negHalf`, multiply the magnitude by
// `cos` / `sin`, then patch the results for infinite / NaN inputs and for a
// zero input.
// NOTE(review): several original lines are missing from this listing, so
// `b`, `zero`, `negOne`, `negHalf`, `sin`, `cos` and some assignment targets
// are defined out of view.
1279 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1282 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1283 auto elementType = cast<FloatType>(type.getElementType());
1285 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
// Helper that materializes an APFloat as an arith.constant of the element
// type (the attribute argument is on a line not shown).
1287 auto cst = [&](APFloat v) {
1288 return b.
create<arith::ConstantOp>(elementType,
1291 const auto &floatSemantics = elementType.getFloatSemantics();
1293 Value inf = cst(APFloat::getInf(floatSemantics));
1296 Value nan = cst(APFloat::getNaN(floatSemantics));
1298 Value real = b.
create<complex::ReOp>(elementType, adaptor.getComplex());
1299 Value imag = b.
create<complex::ImOp>(elementType, adaptor.getComplex());
// Magnitude term, requested in the AbsFn::rsqrt flavor of computeAbs.
1300 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
// Angle term: atan2(imag, real) multiplied by `negHalf` (presumably -0.5).
1301 Value argArg = b.
create<math::Atan2Op>(imag, real, fmf);
1302 Value rsqrtArg = b.
create<arith::MulFOp>(argArg, negHalf, fmf);
1306 Value resultReal = b.
create<arith::MulFOp>(absRsqrt, cos, fmf);
1307 Value resultImag = b.
create<arith::MulFOp>(absRsqrt, sin, fmf);
// Special-case handling is emitted unless BOTH nnan and ninf are set.
1309 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1310 arith::FastMathFlags::ninf)) {
// Zeros carrying the sign of each input part; the imaginary one is then
// multiplied by negOne.
1314 Value realSignedZero = b.
create<math::CopySignOp>(zero, real, fmf);
1315 Value imagSignedZero = b.
create<math::CopySignOp>(zero, imag, fmf);
1316 Value negImagSignedZero =
1317 b.
create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1319 Value absReal = b.
create<math::AbsFOp>(real, fmf);
1320 Value absImag = b.
create<math::AbsFOp>(imag, fmf);
1322 Value absImagIsInf =
1323 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
// real is NaN (unordered with itself); presumably bound to `realIsNan`.
1325 b.
create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
// |real| == inf; presumably bound to `realIsInf`.
1327 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
// The result is forced to a signed zero when (|imag| is inf and real is NaN)
// or |real| is inf.
1328 Value inIsNanInf = b.
create<arith::AndIOp>(absImagIsInf, realIsNan);
1330 Value resultIsZero = b.
create<arith::OrIOp>(inIsNanInf, realIsInf);
1333 b.
create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1334 resultImag = b.
create<arith::SelectOp>(resultIsZero, negImagSignedZero,
// Zero input (both parts compare equal to zero): real result becomes inf and
// imaginary result becomes NaN. The two comparisons are presumably bound to
// `isRealZero` / `isImagZero`.
1339 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1341 b.
create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1342 Value isZero = b.
create<arith::AndIOp>(isRealZero, isImagZero);
1344 resultReal = b.
create<arith::SelectOp>(isZero, inf, resultReal);
1345 resultImag = b.
create<arith::SelectOp>(isZero, nan, resultImag);
// Rewrite for complex.angle. The visible statements read the real and
// imaginary parts of the operand, using the op's own result type as the
// component type.
// NOTE(review): how the two parts are combined is on lines missing from this
// listing (presumably an atan2 of imag and real — TODO confirm).
1357 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1359 auto loc = op.getLoc();
1360 auto type = op.getType();
1361 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1364 rewriter.
create<complex::ReOp>(loc, type, adaptor.getComplex());
1366 rewriter.
create<complex::ImOp>(loc, type, adaptor.getComplex());
// Excerpt of a list of conversion pattern types; presumably the template
// arguments of the patterns.add<...> call inside
// populateComplexToStandardConversionPatterns — TODO confirm, the call head
// and the entries between these lines are missing from this listing.
// Add/Sub map to the matching arith float op; Equal/NotEqual map to the
// OEQ / UNE compare predicates; Tan and Tanh share TanTanhOpConversion.
1383 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1384 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1385 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1386 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1399 TanTanhOpConversion<complex::TanOp>,
1400 TanTanhOpConversion<complex::TanhOp>,
// Pass type for the Complex-to-Standard conversion. It derives from the
// generated impl::ConvertComplexToStandardBase (CRTP on this struct) and
// overrides runOnOperation(), which is defined out of line.
1408 struct ConvertComplexToStandardPass
1409 :
public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1410 void runOnOperation()
override;
// Pass entry point. The visible statements configure a conversion target and
// report failure.
// NOTE(review): pattern population and the conversion driver call are on
// lines missing from this listing.
1413 void ConvertComplexToStandardPass::runOnOperation() {
// Everything in the arith and math dialects is legal after conversion.
1419 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
// These three complex ops are explicitly kept legal, so they may remain in
// the output.
1420 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
// Marks the pass as failed; presumably guarded by a failed() check on the
// conversion result (the condition is not shown).
1423 signalPassFailure();
// Returns a newly allocated ConvertComplexToStandardPass. Presumably the body
// of createConvertComplexToStandardPass(); the signature is on a line missing
// from this listing.
1428 return std::make_unique<ConvertComplexToStandardPass>();
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.