20#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
/// Selects which quantity computeAbs derives from the magnitude |z| of the
/// complex number given as (real, imag).
enum class AbsFn {
  abs,   ///< |z| itself: max(|re|,|im|) * sqrt(1 + (min/max)^2).
  sqrt,  ///< sqrt(|z|): sqrt(max) * (1 + (min/max)^2)^(1/4).
  rsqrt, ///< 1 / sqrt(|z|): built from rsqrt(max) and rsqrt(1 + (min/max)^2).
};
33 Value one = arith::ConstantOp::create(
b,
real.getType(),
34 b.getFloatAttr(
real.getType(), 1.0));
36 Value absReal = math::AbsFOp::create(
b,
real, fmf);
37 Value absImag = math::AbsFOp::create(
b,
imag, fmf);
39 Value max = arith::MaximumFOp::create(
b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(
b, absReal, absImag, fmf);
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(
b,
min,
max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(
b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(
b, ratioSq, one, fmfWithNaNInf);
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(
b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(
b,
min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(
b,
max, fmfWithNaNInf);
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
58 b,
real.getType(),
b.getFloatAttr(
real.getType(), 0.25));
60 Value sqrt = math::SqrtOp::create(
b,
max, fmfWithNaNInf);
62 math::PowFOp::create(
b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(
b, sqrt, p025, fmfWithNaNInf);
65 Value sqrt = math::SqrtOp::create(
b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(
b,
max, sqrt, fmfWithNaNInf);
69 Value isNaN = arith::CmpFOp::create(
b, arith::CmpFPredicate::UNO,
result,
71 return arith::SelectOp::create(
b, isNaN,
min,
result);
74struct AbsOpConversion :
public OpConversionPattern<complex::AbsOp> {
75 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter)
const override {
80 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 Value
real = complex::ReOp::create(
b, adaptor.getComplex());
85 Value
imag = complex::ImOp::create(
b, adaptor.getComplex());
86 rewriter.replaceOp(op, computeAbs(
real,
imag, fmf,
b));
93struct Atan2OpConversion :
public OpConversionPattern<complex::Atan2Op> {
94 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter)
const override {
99 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 Value
lhs = adaptor.getLhs();
106 Value
rhs = adaptor.getRhs();
108 Value rhsSquared = complex::MulOp::create(
b, type,
rhs,
rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(
b, type,
lhs,
lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(
b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(
b, type, rhsSquaredPlusLhsSquared, fmf);
116 arith::ConstantOp::create(
b, elementType,
b.getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(
b, elementType,
118 b.getFloatAttr(elementType, 1));
119 Value i = complex::CreateOp::create(
b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(
b, i,
lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(
b,
rhs, iTimesLhs, fmf);
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(
b, divResult, fmf);
127 Value negativeOne = arith::ConstantOp::create(
128 b, elementType,
b.getFloatAttr(elementType, -1));
129 Value negativeI = complex::CreateOp::create(
b, type, zero, negativeOne);
131 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
136template <
typename ComparisonOp, arith::CmpFPredicate p>
137struct ComparisonOpConversion :
public OpConversionPattern<ComparisonOp> {
138 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
144 matchAndRewrite(ComparisonOp op,
typename ComparisonOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter)
const override {
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
162 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
171template <
typename BinaryComplexOp,
typename BinaryStandardOp>
172struct BinaryComplexOpConversion :
public OpConversionPattern<BinaryComplexOp> {
173 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
176 matchAndRewrite(BinaryComplexOp op,
typename BinaryComplexOp::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter)
const override {
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
180 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
183 Value realLhs = complex::ReOp::create(
b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(
b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(
b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(
b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(
b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(
b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
191 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
197template <
typename TrigonometricOp>
198struct TrigonometricOpConversion :
public OpConversionPattern<TrigonometricOp> {
199 using OpAdaptor =
typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
201 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override {
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc,
imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc,
real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc,
real, fmf);
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
230 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
235 virtual std::pair<Value, Value>
236 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
237 Value cos, ConversionPatternRewriter &rewriter,
238 arith::FastMathFlagsAttr fmf)
const = 0;
241struct CosOpConversion :
public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
244 std::pair<Value, Value>
combine(Location loc, Value scaledExp,
245 Value reciprocalExp, Value sin, Value cos,
246 ConversionPatternRewriter &rewriter,
247 arith::FastMathFlagsAttr fmf)
const override {
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
267struct DivOpConversion :
public OpConversionPattern<complex::DivOp> {
268 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags
target)
269 : OpConversionPattern<complex::DivOp>(context), complexRange(
target) {}
271 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter)
const override {
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
290 Value resultReal, resultImag;
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
297 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
303 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
310 complex::ComplexRangeFlags complexRange;
313struct ExpOpConversion :
public OpConversionPattern<complex::ExpOp> {
314 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
321 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
323 auto loc = op.getLoc();
324 auto type = cast<ComplexType>(adaptor.getComplex().getType());
325 auto ET = cast<FloatType>(type.getElementType());
326 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327 const auto &floatSemantics = ET.getFloatSemantics();
328 ImplicitLocOpBuilder
b(loc, rewriter);
330 Value x = complex::ReOp::create(
b, ET, adaptor.getComplex());
331 Value y = complex::ImOp::create(
b, ET, adaptor.getComplex());
332 Value zero = arith::ConstantOp::create(
b, ET,
b.getZeroAttr(ET));
333 Value half = arith::ConstantOp::create(
b, ET,
b.getFloatAttr(ET, 0.5));
334 Value inf = arith::ConstantOp::create(
335 b, ET,
b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
337 Value exp = math::ExpOp::create(
b, x, fmf);
338 Value xHalf = arith::MulFOp::create(
b, x, half, fmf);
339 Value expHalf = math::ExpOp::create(
b, xHalf, fmf);
340 Value cos = math::CosOp::create(
b, y, fmf);
341 Value sin = math::SinOp::create(
b, y, fmf);
344 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
346 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, y, zero);
349 Value realNormal = arith::MulFOp::create(
b, exp, cos, fmf);
350 Value expHalfCos = arith::MulFOp::create(
b, expHalf, cos, fmf);
351 Value realOverflow = arith::MulFOp::create(
b, expHalfCos, expHalf, fmf);
353 arith::SelectOp::create(
b, expIsInf, realOverflow, realNormal);
357 Value imagNormal = arith::MulFOp::create(
b, exp, sin, fmf);
358 Value expHalfSin = arith::MulFOp::create(
b, expHalf, sin, fmf);
359 Value imagOverflow = arith::MulFOp::create(
b, expHalfSin, expHalf, fmf);
361 arith::SelectOp::create(
b, expIsInf, imagOverflow, imagNormal);
362 Value resultImag = arith::SelectOp::create(
b, yIsZero, zero, imagNonZero);
364 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
372 arith::FastMathFlagsAttr fmf) {
373 auto argType = mlir::cast<FloatType>(arg.
getType());
375 arith::ConstantOp::create(
b,
b.getFloatAttr(argType, coefficients[0]));
376 for (
unsigned i = 1; i < coefficients.size(); ++i) {
377 poly = math::FmaOp::create(
379 arith::ConstantOp::create(
b,
b.getFloatAttr(argType, coefficients[i])),
385struct Expm1OpConversion :
public OpConversionPattern<complex::Expm1Op> {
386 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
393 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
394 ConversionPatternRewriter &rewriter)
const override {
396 auto elemType = mlir::cast<FloatType>(type.getElementType());
398 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
399 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
400 Value
real = complex::ReOp::create(
b, adaptor.getComplex());
401 Value
imag = complex::ImOp::create(
b, adaptor.getComplex());
403 Value zero = arith::ConstantOp::create(
b,
b.getFloatAttr(elemType, 0.0));
404 Value one = arith::ConstantOp::create(
b,
b.getFloatAttr(elemType, 1.0));
406 Value expm1Real = math::ExpM1Op::create(
b,
real, fmf);
407 Value expReal = arith::AddFOp::create(
b, expm1Real, one, fmf);
409 Value sinImag = math::SinOp::create(
b,
imag, fmf);
410 Value cosm1Imag = emitCosm1(
imag, fmf,
b);
411 Value cosImag = arith::AddFOp::create(
b, cosm1Imag, one, fmf);
413 Value realResult = arith::AddFOp::create(
414 b, arith::MulFOp::create(
b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
416 Value imagIsZero = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
imag,
417 zero, fmf.getValue());
418 Value imagResult = arith::SelectOp::create(
419 b, imagIsZero, zero, arith::MulFOp::create(
b, expReal, sinImag, fmf));
421 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
427 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
428 ImplicitLocOpBuilder &
b)
const {
429 auto argType = mlir::cast<FloatType>(arg.
getType());
430 auto negHalf = arith::ConstantOp::create(
b,
b.getFloatAttr(argType, -0.5));
431 auto negOne = arith::ConstantOp::create(
b,
b.getFloatAttr(argType, -1.0));
434 SmallVector<double, 7> kCoeffs{
435 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438 4.1666666666666666609054E-2,
440 Value cos = math::CosOp::create(
b, arg, fmf);
441 Value forLargeArg = arith::AddFOp::create(
b, cos, negOne, fmf);
443 Value argPow2 = arith::MulFOp::create(
b, arg, arg, fmf);
444 Value argPow4 = arith::MulFOp::create(
b, argPow2, argPow2, fmf);
445 Value poly = evaluatePolynomial(
b, argPow2, kCoeffs, fmf);
448 arith::AddFOp::create(
b, arith::MulFOp::create(
b, argPow4, poly, fmf),
449 arith::MulFOp::create(
b, negHalf, argPow2, fmf));
453 arith::ConstantOp::create(
b,
b.getFloatAttr(argType, 0.61685));
454 Value cond = arith::CmpFOp::create(
b, arith::CmpFPredicate::OGE, argPow2,
455 piOver4Pow2, fmf.getValue());
456 return arith::SelectOp::create(
b, cond, forLargeArg, forSmallArg);
460struct LogOpConversion :
public OpConversionPattern<complex::LogOp> {
461 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
464 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter)
const override {
466 auto type = cast<ComplexType>(adaptor.getComplex().getType());
467 auto elementType = cast<FloatType>(type.getElementType());
468 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
469 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
471 Value
abs = complex::AbsOp::create(
b, elementType, adaptor.getComplex(),
473 Value resultReal = math::LogOp::create(
b, elementType, abs, fmf.getValue());
474 Value
real = complex::ReOp::create(
b, elementType, adaptor.getComplex());
475 Value
imag = complex::ImOp::create(
b, elementType, adaptor.getComplex());
477 math::Atan2Op::create(
b, elementType,
imag,
real, fmf.getValue());
478 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
484struct Log1pOpConversion :
public OpConversionPattern<complex::Log1pOp> {
485 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
488 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter)
const override {
490 auto type = cast<ComplexType>(adaptor.getComplex().getType());
491 auto elementType = cast<FloatType>(type.getElementType());
492 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
493 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
495 Value
real = complex::ReOp::create(
b, adaptor.getComplex());
496 Value
imag = complex::ImOp::create(
b, adaptor.getComplex());
498 Value half = arith::ConstantOp::create(
b, elementType,
499 b.getFloatAttr(elementType, 0.5));
500 Value one = arith::ConstantOp::create(
b, elementType,
501 b.getFloatAttr(elementType, 1));
502 Value realPlusOne = arith::AddFOp::create(
b,
real, one, fmf);
503 Value absRealPlusOne = math::AbsFOp::create(
b, realPlusOne, fmf);
504 Value absImag = math::AbsFOp::create(
b,
imag, fmf);
506 Value maxAbs = arith::MaximumFOp::create(
b, absRealPlusOne, absImag, fmf);
507 Value minAbs = arith::MinimumFOp::create(
b, absRealPlusOne, absImag, fmf);
509 Value useReal = arith::CmpFOp::create(
b, arith::CmpFPredicate::OGT,
510 realPlusOne, absImag, fmf);
511 Value maxMinusOne = arith::SubFOp::create(
b, maxAbs, one, fmf);
512 Value maxAbsOfRealPlusOneAndImagMinusOne =
513 arith::SelectOp::create(
b, useReal,
real, maxMinusOne);
514 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516 Value minMaxRatio = arith::DivFOp::create(
b, minAbs, maxAbs, fmfWithNaNInf);
517 Value logOfMaxAbsOfRealPlusOneAndImag =
518 math::Log1pOp::create(
b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519 Value logOfSqrtPart = math::Log1pOp::create(
520 b, arith::MulFOp::create(
b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
522 Value r = arith::AddFOp::create(
523 b, arith::MulFOp::create(
b, half, logOfSqrtPart, fmfWithNaNInf),
524 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525 Value resultReal = arith::SelectOp::create(
527 arith::CmpFOp::create(
b, arith::CmpFPredicate::UNO, r, r,
530 Value resultImag = math::Atan2Op::create(
b,
imag, realPlusOne, fmf);
531 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
537struct MulOpConversion :
public OpConversionPattern<complex::MulOp> {
538 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
541 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter)
const override {
543 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
544 auto type = cast<ComplexType>(adaptor.getLhs().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547 auto fmfValue = fmf.getValue();
548 Value lhsReal = complex::ReOp::create(
b, elementType, adaptor.getLhs());
549 Value lhsImag = complex::ImOp::create(
b, elementType, adaptor.getLhs());
550 Value rhsReal = complex::ReOp::create(
b, elementType, adaptor.getRhs());
551 Value rhsImag = complex::ImOp::create(
b, elementType, adaptor.getRhs());
554 if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
555 Value lhsImagTimesRhsImag =
556 arith::MulFOp::create(
b, lhsImag, rhsImag, fmfValue);
557 Value negLhsImagTimesRhsImag =
558 arith::NegFOp::create(
b, lhsImagTimesRhsImag, fmfValue);
559 real = math::FmaOp::create(
b, lhsReal, rhsReal, negLhsImagTimesRhsImag,
562 Value lhsImagTimesRhsReal =
563 arith::MulFOp::create(
b, lhsImag, rhsReal, fmfValue);
564 imag = math::FmaOp::create(
b, lhsReal, rhsImag, lhsImagTimesRhsReal,
567 Value lhsRealTimesRhsReal =
568 arith::MulFOp::create(
b, lhsReal, rhsReal, fmfValue);
569 Value lhsImagTimesRhsImag =
570 arith::MulFOp::create(
b, lhsImag, rhsImag, fmfValue);
571 Value lhsImagTimesRhsReal =
572 arith::MulFOp::create(
b, lhsImag, rhsReal, fmfValue);
573 Value lhsRealTimesRhsImag =
574 arith::MulFOp::create(
b, lhsReal, rhsImag, fmfValue);
576 real = arith::SubFOp::create(
b, lhsRealTimesRhsReal, lhsImagTimesRhsImag,
578 imag = arith::AddFOp::create(
b, lhsImagTimesRhsReal, lhsRealTimesRhsImag,
581 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type,
real,
imag);
586struct NegOpConversion :
public OpConversionPattern<complex::NegOp> {
587 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
590 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter)
const override {
592 auto loc = op.getLoc();
593 auto type = cast<ComplexType>(adaptor.getComplex().getType());
594 auto elementType = cast<FloatType>(type.getElementType());
597 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
599 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
600 Value negReal = arith::NegFOp::create(rewriter, loc,
real);
601 Value negImag = arith::NegFOp::create(rewriter, loc,
imag);
602 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
607struct SinOpConversion :
public TrigonometricOpConversion<complex::SinOp> {
608 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
610 std::pair<Value, Value>
combine(Location loc, Value scaledExp,
611 Value reciprocalExp, Value sin, Value cos,
612 ConversionPatternRewriter &rewriter,
613 arith::FastMathFlagsAttr fmf)
const override {
624 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
625 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
627 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
628 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
629 return {resultReal, resultImag};
634struct SqrtOpConversion :
public OpConversionPattern<complex::SqrtOp> {
635 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
638 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter)
const override {
640 ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
642 auto type = cast<ComplexType>(op.getType());
643 auto elementType = cast<FloatType>(type.getElementType());
644 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
646 auto cst = [&](APFloat v) {
647 return arith::ConstantOp::create(
b, elementType,
648 b.getFloatAttr(elementType, v));
650 const auto &floatSemantics = elementType.getFloatSemantics();
651 Value zero = cst(APFloat::getZero(floatSemantics));
652 Value half = arith::ConstantOp::create(
b, elementType,
653 b.getFloatAttr(elementType, 0.5));
655 Value
real = complex::ReOp::create(
b, elementType, adaptor.getComplex());
656 Value
imag = complex::ImOp::create(
b, elementType, adaptor.getComplex());
657 Value absSqrt = computeAbs(
real,
imag, fmf,
b, AbsFn::sqrt);
658 Value argArg = math::Atan2Op::create(
b,
imag,
real, fmf);
659 Value sqrtArg = arith::MulFOp::create(
b, argArg, half, fmf);
660 Value cos = math::CosOp::create(
b, sqrtArg, fmf);
661 Value sin = math::SinOp::create(
b, sqrtArg, fmf);
665 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
667 Value resultReal = arith::MulFOp::create(
b, absSqrt, cos, fmf);
668 Value resultImag = arith::SelectOp::create(
669 b, sinIsZero, zero, arith::MulFOp::create(
b, absSqrt, sin, fmf));
670 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
671 arith::FastMathFlags::ninf)) {
672 Value inf = cst(APFloat::getInf(floatSemantics));
673 Value negInf = cst(APFloat::getInf(floatSemantics,
true));
674 Value nan = cst(APFloat::getNaN(floatSemantics));
675 Value absImag = math::AbsFOp::create(
b, elementType,
imag, fmf);
677 Value absImagIsInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
679 Value absImagIsNotInf = arith::CmpFOp::create(
680 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
682 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
real, inf, fmf);
683 Value realIsNegInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
686 resultReal = arith::SelectOp::create(
687 b, arith::AndIOp::create(
b, realIsNegInf, absImagIsNotInf), zero,
689 resultReal = arith::SelectOp::create(
690 b, arith::OrIOp::create(
b, absImagIsInf, realIsInf), inf, resultReal);
692 Value imagSignInf = math::CopySignOp::create(
b, inf,
imag, fmf);
693 resultImag = arith::SelectOp::create(
695 arith::CmpFOp::create(
b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
697 resultImag = arith::SelectOp::create(
698 b, arith::OrIOp::create(
b, absImagIsInf, realIsNegInf), imagSignInf,
703 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
704 resultReal = arith::SelectOp::create(
b, resultIsZero, zero, resultReal);
705 resultImag = arith::SelectOp::create(
b, resultIsZero, zero, resultImag);
707 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
713struct SignOpConversion :
public OpConversionPattern<complex::SignOp> {
714 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
717 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter)
const override {
719 auto type = cast<ComplexType>(adaptor.getComplex().getType());
720 auto elementType = cast<FloatType>(type.getElementType());
721 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
722 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
724 Value
real = complex::ReOp::create(
b, elementType, adaptor.getComplex());
725 Value
imag = complex::ImOp::create(
b, elementType, adaptor.getComplex());
727 arith::ConstantOp::create(
b, elementType,
b.getZeroAttr(elementType));
729 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
real, zero);
731 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
imag, zero);
732 Value isZero = arith::AndIOp::create(
b, realIsZero, imagIsZero);
734 complex::AbsOp::create(
b, elementType, adaptor.getComplex(), fmf);
735 Value realSign = arith::DivFOp::create(
b,
real, abs, fmf);
736 Value imagSign = arith::DivFOp::create(
b,
imag, abs, fmf);
737 Value sign = complex::CreateOp::create(
b, type, realSign, imagSign);
738 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
739 adaptor.getComplex(), sign);
744template <
typename Op>
745struct TanTanhOpConversion :
public OpConversionPattern<Op> {
746 using OpConversionPattern<
Op>::OpConversionPattern;
749 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
750 ConversionPatternRewriter &rewriter)
const override {
751 ImplicitLocOpBuilder
b(op.
getLoc(), rewriter);
753 auto type = cast<ComplexType>(adaptor.getComplex().getType());
754 auto elementType = cast<FloatType>(type.getElementType());
755 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
756 const auto &floatSemantics = elementType.getFloatSemantics();
759 complex::ReOp::create(
b, loc, elementType, adaptor.getComplex());
761 complex::ImOp::create(
b, loc, elementType, adaptor.getComplex());
762 Value negOne = arith::ConstantOp::create(
b, elementType,
763 b.getFloatAttr(elementType, -1.0));
765 if constexpr (std::is_same_v<Op, complex::TanOp>) {
768 real = arith::MulFOp::create(
b,
real, negOne, fmf);
771 auto cst = [&](APFloat v) {
772 return arith::ConstantOp::create(
b, elementType,
773 b.getFloatAttr(elementType, v));
775 Value inf = cst(APFloat::getInf(floatSemantics));
776 Value four = arith::ConstantOp::create(
b, elementType,
777 b.getFloatAttr(elementType, 4.0));
778 Value twoReal = arith::AddFOp::create(
b,
real,
real, fmf);
779 Value negTwoReal = arith::MulFOp::create(
b, negOne, twoReal, fmf);
781 Value expTwoRealMinusOne = math::ExpM1Op::create(
b, twoReal, fmf);
782 Value expNegTwoRealMinusOne = math::ExpM1Op::create(
b, negTwoReal, fmf);
783 Value realNum = arith::SubFOp::create(
b, expTwoRealMinusOne,
784 expNegTwoRealMinusOne, fmf);
786 Value cosImag = math::CosOp::create(
b,
imag, fmf);
787 Value cosImagSq = arith::MulFOp::create(
b, cosImag, cosImag, fmf);
788 Value twoCosTwoImagPlusOne = arith::MulFOp::create(
b, cosImagSq, four, fmf);
789 Value sinImag = math::SinOp::create(
b,
imag, fmf);
791 Value imagNum = arith::MulFOp::create(
792 b, four, arith::MulFOp::create(
b, cosImag, sinImag, fmf), fmf);
794 Value expSumMinusTwo = arith::AddFOp::create(
b, expTwoRealMinusOne,
795 expNegTwoRealMinusOne, fmf);
797 arith::AddFOp::create(
b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
799 Value isInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
800 expSumMinusTwo, inf, fmf);
801 Value realLimit = math::CopySignOp::create(
b, negOne,
real, fmf);
803 Value resultReal = arith::SelectOp::create(
804 b, isInf, realLimit, arith::DivFOp::create(
b, realNum, denom, fmf));
805 Value resultImag = arith::DivFOp::create(
b, imagNum, denom, fmf);
807 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
808 arith::FastMathFlags::ninf)) {
809 Value absReal = math::AbsFOp::create(
b,
real, fmf);
810 Value zero = arith::ConstantOp::create(
b, elementType,
811 b.getFloatAttr(elementType, 0.0));
812 Value nan = cst(APFloat::getNaN(floatSemantics));
814 Value absRealIsInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
817 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
imag, zero, fmf);
818 Value absRealIsNotInf = arith::XOrIOp::create(
821 Value imagNumIsNaN = arith::CmpFOp::create(
b, arith::CmpFPredicate::UNO,
822 imagNum, imagNum, fmf);
823 Value resultRealIsNaN =
824 arith::AndIOp::create(
b, imagNumIsNaN, absRealIsNotInf);
825 Value resultImagIsZero = arith::OrIOp::create(
826 b, imagIsZero, arith::AndIOp::create(
b, absRealIsInf, imagNumIsNaN));
828 resultReal = arith::SelectOp::create(
b, resultRealIsNaN, nan, resultReal);
830 arith::SelectOp::create(
b, resultImagIsZero, zero, resultImag);
833 if constexpr (std::is_same_v<Op, complex::TanOp>) {
835 std::swap(resultReal, resultImag);
836 resultImag = arith::MulFOp::create(
b, resultImag, negOne, fmf);
839 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
845struct ConjOpConversion :
public OpConversionPattern<complex::ConjOp> {
846 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
849 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
850 ConversionPatternRewriter &rewriter)
const override {
852 auto type = cast<ComplexType>(adaptor.getComplex().getType());
853 auto elementType = cast<FloatType>(type.getElementType());
855 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
857 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
858 Value negImag = arith::NegFOp::create(rewriter, loc, elementType,
imag);
860 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type,
real, negImag);
871 arith::FastMathFlags fmf) {
872 auto elementType = cast<FloatType>(type.getElementType());
874 Value a = complex::ReOp::create(builder,
lhs);
875 Value b = complex::ImOp::create(builder,
lhs);
877 Value abs = complex::AbsOp::create(builder,
lhs, fmf);
878 Value absToC = math::PowFOp::create(builder, abs, c, fmf);
880 Value negD = arith::NegFOp::create(builder, d, fmf);
881 Value argLhs = math::Atan2Op::create(builder,
b, a, fmf);
882 Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
883 Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
885 Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
886 Value lnAbs = math::LogOp::create(builder, abs, fmf);
887 Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
888 Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
889 Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
890 Value cosQ = math::CosOp::create(builder, q, fmf);
891 Value sinQ = math::SinOp::create(builder, q, fmf);
893 Value inf = arith::ConstantOp::create(
894 builder, elementType,
896 APFloat::getInf(elementType.getFloatSemantics())));
897 Value zero = arith::ConstantOp::create(
898 builder, elementType, builder.
getFloatAttr(elementType, 0.0));
899 Value one = arith::ConstantOp::create(builder, elementType,
901 Value complexOne = complex::CreateOp::create(builder, type, one, zero);
902 Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
903 Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
910 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
912 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
914 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
916 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ,
b, zero, fmf);
919 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
920 Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
921 Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
922 Value complexOneOrZero =
923 arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
925 complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
926 Value cutoff0 = arith::SelectOp::create(
928 arith::AndIOp::create(
929 builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
930 complexOneOrZero, coeffCosSin);
936 Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
938 arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
942 Value lhsEqOne = arith::AndIOp::create(
944 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
947 arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
951 Value lhsEqInf = arith::AndIOp::create(
953 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
955 Value rhsGt0 = arith::AndIOp::create(
957 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
958 Value cutoff3 = arith::SelectOp::create(
959 builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
964 Value rhsLt0 = arith::AndIOp::create(
966 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
967 Value cutoff4 = arith::SelectOp::create(
968 builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
// Lowers complex.powi (complex base, integer exponent) by re-expressing it as
// complex.pow: the integer exponent is converted with arith.sitofp to the
// float element type of the result, wrapped as a complex value with a zero
// imaginary part, and passed to complex.pow together with the original op's
// fastmath attribute.
974struct PowiOpConversion :
public OpConversionPattern<complex::PowiOp> {
975 using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
978 matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
979 ConversionPatternRewriter &rewriter)
const override {
980 ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
981 auto type = cast<ComplexType>(op.getType());
982 auto elementType = cast<FloatType>(type.getElementType());
// Signed integer exponent -> floating-point real part of the exponent.
984 Value floatExponent =
985 arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
986 Value zero = arith::ConstantOp::create(
987 builder, elementType, builder.
getFloatAttr(elementType, 0.0));
// Complex exponent = floatExponent + 0i.
988 Value complexExponent =
989 complex::CreateOp::create(builder, type, floatExponent, zero);
// Emit complex.pow in place of complex.powi. NOTE(review): the new op is
// presumably expanded in turn by PowOpConversion; confirm it is registered
// in the same pattern set.
991 auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
992 complexExponent, op.getFastmathAttr());
993 rewriter.replaceOp(op, pow.getResult());
// Lowers complex.pow(lhs, rhs). The exponent `rhs` is split into its real
// part `c` and imaginary part `d` (complex.re / complex.im). The expansion is
// delegated to powOpConversionImpl, which receives the base `lhs`, its
// complex type and the op's fastmath flags.
998struct PowOpConversion :
public OpConversionPattern<complex::PowOp> {
999 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
1002 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1003 ConversionPatternRewriter &rewriter)
const override {
1004 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
// The complex type is taken from the (converted) base operand.
1005 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1006 auto elementType = cast<FloatType>(type.getElementType());
// Exponent rhs = c + d*i.
1008 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
1009 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
1011 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1012 c, d, op.getFastmath())});
// Lowers complex.rsqrt (1 / sqrt(z)) using the polar form.
//   magnitude: computeAbs(re, im, fmf, b, AbsFn::rsqrt)
//   angle:     -atan2(im, re) / 2
//   result:    magnitude * (cos(angle), sin(angle))
// Special-value fix-ups are emitted only when the fastmath flags do not
// assume both nnan and ninf. A zero input (re == 0 and im == 0) maps to
// (inf, nan).
1017struct RsqrtOpConversion :
public OpConversionPattern<complex::RsqrtOp> {
1018 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
1021 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter)
const override {
1023 mlir::ImplicitLocOpBuilder
b(op.getLoc(), rewriter);
1024 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1025 auto elementType = cast<FloatType>(type.getElementType());
1027 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
// Materializes an arith.constant of the element type from an APFloat.
1029 auto cst = [&](APFloat v) {
1030 return arith::ConstantOp::create(
b, elementType,
1031 b.getFloatAttr(elementType, v));
// Constants are built in the element type's own float semantics.
1033 const auto &floatSemantics = elementType.getFloatSemantics();
1034 Value zero = cst(APFloat::getZero(floatSemantics));
1035 Value inf = cst(APFloat::getInf(floatSemantics));
1036 Value negHalf = arith::ConstantOp::create(
1037 b, elementType,
b.getFloatAttr(elementType, -0.5));
1038 Value nan = cst(APFloat::getNaN(floatSemantics));
1040 Value
real = complex::ReOp::create(
b, elementType, adaptor.getComplex());
1041 Value
imag = complex::ImOp::create(
b, elementType, adaptor.getComplex());
// Magnitude term. In AbsFn::rsqrt mode, computeAbs (defined earlier in this
// file) applies rsqrt to max(|re|,|im|) and to 1 + (min/max)^2 before
// combining them. NOTE(review): presumably this yields |z|^(-1/2); confirm.
1042 Value absRsqrt = computeAbs(
real,
imag, fmf,
b, AbsFn::rsqrt);
// arg(z) = atan2(im, re); rsqrt uses -arg(z) / 2.
1043 Value argArg = math::Atan2Op::create(
b,
imag,
real, fmf);
1044 Value rsqrtArg = arith::MulFOp::create(
b, argArg, negHalf, fmf);
1045 Value cos = math::CosOp::create(
b, rsqrtArg, fmf);
1046 Value sin = math::SinOp::create(
b, rsqrtArg, fmf);
// Polar -> Cartesian: magnitude * (cos, sin).
1048 Value resultReal = arith::MulFOp::create(
b, absRsqrt, cos, fmf);
1049 Value resultImag = arith::MulFOp::create(
b, absRsqrt, sin, fmf);
// Special-value handling is skipped only when the flags promise both
// no-NaN and no-Inf.
1051 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1052 arith::FastMathFlags::ninf)) {
1053 Value negOne = arith::ConstantOp::create(
b, elementType,
1054 b.getFloatAttr(elementType, -1));
// Signed-zero candidates for the result:
//   real part:      copysign(0, real)
//   imaginary part: -copysign(0, imag)
1056 Value realSignedZero = math::CopySignOp::create(
b, zero,
real, fmf);
1057 Value imagSignedZero = math::CopySignOp::create(
b, zero,
imag, fmf);
1058 Value negImagSignedZero =
1059 arith::MulFOp::create(
b, negOne, imagSignedZero, fmf);
1061 Value absReal = math::AbsFOp::create(
b,
real, fmf);
1062 Value absImag = math::AbsFOp::create(
b,
imag, fmf);
// The result is forced to signed zero when
//   (absImagIsInf && realIsNan) || realIsInf.
// NOTE(review): the OEQ operands of absImagIsInf/realIsInf are presumably
// |imag| == inf and |real| == inf; confirm against the full source.
1064 Value absImagIsInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
// Unordered self-comparison: true iff `real` is NaN.
1067 arith::CmpFOp::create(
b, arith::CmpFPredicate::UNO,
real,
real, fmf);
1068 Value realIsInf = arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
1070 Value inIsNanInf = arith::AndIOp::create(
b, absImagIsInf, realIsNan);
1072 Value resultIsZero = arith::OrIOp::create(
b, inIsNanInf, realIsInf);
// Where resultIsZero holds, replace both components with the signed zeros.
1075 arith::SelectOp::create(
b, resultIsZero, realSignedZero, resultReal);
1076 resultImag = arith::SelectOp::create(
b, resultIsZero, negImagSignedZero,
// Zero input: re == 0 && im == 0 (ordered compares) maps to (inf, nan).
1081 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
real, zero, fmf);
1083 arith::CmpFOp::create(
b, arith::CmpFPredicate::OEQ,
imag, zero, fmf);
1084 Value isZero = arith::AndIOp::create(
b, isRealZero, isImagZero);
1086 resultReal = arith::SelectOp::create(
b, isZero, inf, resultReal);
1087 resultImag = arith::SelectOp::create(
b, isZero, nan, resultImag);
// Assemble the final complex value from the (possibly patched) parts.
1089 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
// Lowers complex.angle(z) to math.atan2(im(z), re(z)), forwarding the
// op's fastmath attribute.
1095struct AngleOpConversion :
public OpConversionPattern<complex::AngleOp> {
1096 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1099 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1100 ConversionPatternRewriter &rewriter)
const override {
1101 auto loc = op.getLoc();
// The op's result type is also used as the type of the extracted re/im
// parts. NOTE(review): this presumably equals the complex element type.
1102 auto type = op.getType();
1103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1106 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1108 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
// atan2 takes (y, x) = (imag, real).
1110 rewriter.replaceOpWithNewOp<math::Atan2Op>(op,
imag,
real, fmf);
1125 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1126 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1127 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1128 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1140 TanTanhOpConversion<complex::TanOp>,
1141 TanTanhOpConversion<complex::TanhOp>,
1147 patterns.
add<DivOpConversion>(patterns.
getContext(), complexRange);
// Pass that lowers complex-dialect operations to the arith and math
// dialects. It derives from the tablegen-generated pass base class.
1153struct ConvertComplexToStandardPass
1154 :
public impl::ConvertComplexToStandardPassBase<
1155 ConvertComplexToStandardPass> {
// Defined out of line below.
1158 void runOnOperation()
override;
// Runs the complex-to-standard lowering on the current operation.
1161void ConvertComplexToStandardPass::runOnOperation() {
// Conversion target:
//   - arith and math ops are legal;
//   - complex.create, complex.im and complex.re are the only complex ops
//     explicitly marked legal.
1167 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1168 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
// Apply the rewrite patterns as a partial conversion of the current
// operation. If the conversion fails (the condition is presumably wrapped
// in `failed(...)`), the pass reports failure.
1170 applyPartialConversion(getOperation(),
target, std::move(patterns))))
1171 signalPassFailure();
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
constexpr T real(const NonFloatComplex< T > &x)
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
constexpr T imag(const NonFloatComplex< T > &x)