MLIR  20.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 #include <type_traits>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
// Selects which quantity computeAbs produces for a complex number z:
// |z| (abs), sqrt(|z|) (sqrt), or 1 / sqrt(|z|) (rsqrt).
enum class AbsFn { abs, sqrt, rsqrt };
31 
// Returns the absolute value, its square root or its reciprocal square root.
//
// For z = real + i*imag, |z| is evaluated in the scaled form
//   |z| = max * sqrt(1 + (min / max)^2)
// with max = max(|real|, |imag|) and min = min(|real|, |imag|), instead of
// sqrt(real^2 + imag^2). The ratio min / max is at most 1 for ordinary
// inputs, so squaring it cannot overflow.
//
// Parameters:
//   real, imag - the two components of z (both of the same float type).
//   fmf        - fast-math flags for the abs/max/min ops; every later op uses
//                these flags with `nnan` and `ninf` cleared.
//   b          - builder; all ops are emitted at its insertion point/location.
//   fn         - which of |z|, sqrt(|z|), rsqrt(|z|) to produce.
// Returns the result of the final arith.select.
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
                 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
  Value one = b.create<arith::ConstantOp>(real.getType(),
                                          b.getFloatAttr(real.getType(), 1.0));

  Value absReal = b.create<math::AbsFOp>(real, fmf);
  Value absImag = b.create<math::AbsFOp>(imag, fmf);

  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);

  // The lowering below requires NaNs and infinities to work correctly.
  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
      fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
  // ratioSqPlusOne = 1 + (min / max)^2.
  Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
  Value result;

  if (fn == AbsFn::rsqrt) {
    // rsqrt(|z|) = rsqrt(max) * sqrt(rsqrt(1 + ratio^2)). Replacing the
    // operands here lets the shared `else` branch below (max * sqrt(...))
    // produce exactly that product. `min` is transformed as well so that the
    // NaN fallback at the end yields rsqrt(min).
    ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
    min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
    max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
  }

  if (fn == AbsFn::sqrt) {
    // sqrt(|z|) = sqrt(max) * (1 + ratio^2)^(1/4).
    Value quarter = b.create<arith::ConstantOp>(
        real.getType(), b.getFloatAttr(real.getType(), 0.25));
    // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
    Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
    Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
    result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
  } else {
    // abs:   max * sqrt(1 + ratio^2).
    // rsqrt: evaluated on the operands replaced in the branch above.
    Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
    result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
  }

  // `ratio` is NaN when min and max are both 0 (0/0), both infinite (inf/inf),
  // or NaN, which makes `result` NaN. In the 0 and inf cases min == max, and
  // sqrt(x) == x for those values, so returning `min` (untransformed for
  // sqrt, rsqrt-transformed above for rsqrt) gives the expected value.
  Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
                                        result, fmfWithNaNInf);
  return b.create<arith::SelectOp>(isNaN, min, result);
}
74 
75 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
77 
78  LogicalResult
79  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
80  ConversionPatternRewriter &rewriter) const override {
81  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
82 
83  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 
85  Value real = b.create<complex::ReOp>(adaptor.getComplex());
86  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
87  rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
88 
89  return success();
90  }
91 };
92 
93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
96 
97  LogicalResult
98  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
99  ConversionPatternRewriter &rewriter) const override {
100  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
101 
102  auto type = cast<ComplexType>(op.getType());
103  Type elementType = type.getElementType();
104  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 
106  Value lhs = adaptor.getLhs();
107  Value rhs = adaptor.getRhs();
108 
109  Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
110  Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
111  Value rhsSquaredPlusLhsSquared =
112  b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113  Value sqrtOfRhsSquaredPlusLhsSquared =
114  b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
115 
116  Value zero =
117  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
118  Value one = b.create<arith::ConstantOp>(elementType,
119  b.getFloatAttr(elementType, 1));
120  Value i = b.create<complex::CreateOp>(type, zero, one);
121  Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
122  Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
123 
124  Value divResult = b.create<complex::DivOp>(
125  rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126  Value logResult = b.create<complex::LogOp>(divResult, fmf);
127 
128  Value negativeOne = b.create<arith::ConstantOp>(
129  elementType, b.getFloatAttr(elementType, -1));
130  Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
131 
132  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
133  return success();
134  }
135 };
136 
137 template <typename ComparisonOp, arith::CmpFPredicate p>
138 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
140  using ResultCombiner =
141  std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142  arith::AndIOp, arith::OrIOp>;
143 
144  LogicalResult
145  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
146  ConversionPatternRewriter &rewriter) const override {
147  auto loc = op.getLoc();
148  auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
149 
150  Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
151  Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
152  Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
153  Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
154  Value realComparison =
155  rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156  Value imagComparison =
157  rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
158 
159  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
160  imagComparison);
161  return success();
162  }
163 };
164 
165 // Default conversion which applies the BinaryStandardOp separately on the real
166 // and imaginary parts. Can for example be used for complex::AddOp and
167 // complex::SubOp.
168 template <typename BinaryComplexOp, typename BinaryStandardOp>
169 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
171 
172  LogicalResult
173  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
174  ConversionPatternRewriter &rewriter) const override {
175  auto type = cast<ComplexType>(adaptor.getLhs().getType());
176  auto elementType = cast<FloatType>(type.getElementType());
177  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
178  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
179 
180  Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
181  Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
182  Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
183  fmf.getValue());
184  Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
185  Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
186  Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
187  fmf.getValue());
188  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
189  resultImag);
190  return success();
191  }
192 };
193 
194 template <typename TrigonometricOp>
195 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
197 
199 
200  LogicalResult
201  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  auto loc = op.getLoc();
204  auto type = cast<ComplexType>(adaptor.getComplex().getType());
205  auto elementType = cast<FloatType>(type.getElementType());
206  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
207 
208  Value real =
209  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
210  Value imag =
211  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
212 
213  // Trigonometric ops use a set of common building blocks to convert to real
214  // ops. Here we create these building blocks and call into an op-specific
215  // implementation in the subclass to combine them.
216  Value half = rewriter.create<arith::ConstantOp>(
217  loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
218  Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
219  Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
220  Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
221  Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
222  Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
223 
224  auto resultPair =
225  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
226 
227  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
228  resultPair.second);
229  return success();
230  }
231 
232  virtual std::pair<Value, Value>
233  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234  Value cos, ConversionPatternRewriter &rewriter,
235  arith::FastMathFlagsAttr fmf) const = 0;
236 };
237 
238 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
239  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
240 
241  std::pair<Value, Value> combine(Location loc, Value scaledExp,
242  Value reciprocalExp, Value sin, Value cos,
243  ConversionPatternRewriter &rewriter,
244  arith::FastMathFlagsAttr fmf) const override {
245  // Complex cosine is defined as;
246  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
247  // Plugging in:
248  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250  // and defining t := exp(y)
251  // We get:
252  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
254  Value sum =
255  rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
257  Value diff =
258  rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
260  return {resultReal, resultImag};
261  }
262 };
263 
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {

  // Lowers complex.div (lhs / rhs) to real arithmetic. The quotient is first
  // computed with Smith's algorithm. If both components of that result are
  // NaN, a special-case result is substituted instead (zero denominator,
  // infinite numerator, or infinite denominator; see the cases below).
  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto type = cast<ComplexType>(adaptor.getLhs().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();

    Value lhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
    Value lhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
    Value rhsReal =
        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
    Value rhsImag =
        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());

    // Smith's algorithm to divide complex numbers. It is just a bit smarter
    // way to compute the following formula:
    //  (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
    //    = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
    //          ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
    //    = ((lhsReal * rhsReal + lhsImag * rhsImag) +
    //          (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
    //
    // Depending on whether |rhsReal| < |rhsImag| we compute either
    //   rhsRealImagRatio = rhsReal / rhsImag
    //   rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
    //   resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
    //   resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
    //
    // or
    //
    //   rhsImagRealRatio = rhsImag / rhsReal
    //   rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
    //   resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
    //   resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
    //
    // See https://dl.acm.org/citation.cfm?id=368661 for more details.
    //
    // Both variants are emitted unconditionally; an arith.select further
    // down picks one of them.

    // Variant 1 (|rhsReal| < |rhsImag|): scale by rhsReal / rhsImag.
    Value rhsRealImagRatio =
        rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
    Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
        loc, rhsImag,
        rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
        fmf);
    Value realNumerator1 = rewriter.create<arith::AddFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
        lhsImag, fmf);
    Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
                                                       rhsRealImagDenom, fmf);
    Value imagNumerator1 = rewriter.create<arith::SubFOp>(
        loc,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
        lhsReal, fmf);
    Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
                                                       rhsRealImagDenom, fmf);

    // Variant 2 (otherwise): scale by rhsImag / rhsReal.
    Value rhsImagRealRatio =
        rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
    Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
        loc, rhsReal,
        rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
        fmf);
    Value realNumerator2 = rewriter.create<arith::AddFOp>(
        loc, lhsReal,
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
        fmf);
    Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
                                                       rhsImagRealDenom, fmf);
    Value imagNumerator2 = rewriter.create<arith::SubFOp>(
        loc, lhsImag,
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
        fmf);
    Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
                                                       rhsImagRealDenom, fmf);

    // Consider corner cases.
    // Case 1. Zero denominator, numerator contains at most one NaN value.
    // Result: each lhs component multiplied by +-inf, where the sign is
    // taken from rhsReal.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getZeroAttr(elementType));
    Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
    Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
    Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
    Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
    // ORD against the constant 0 is true iff the operand is not NaN.
    Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsReal, zero);
    Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ORD, lhsImag, zero);
    Value lhsContainsNotNaNValue =
        rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
    Value resultIsInfinity = rewriter.create<arith::AndIOp>(
        loc, lhsContainsNotNaNValue,
        rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
    Value inf = rewriter.create<arith::ConstantOp>(
        loc, elementType,
        rewriter.getFloatAttr(
            elementType, APFloat::getInf(elementType.getFloatSemantics())));
    Value infWithSignOfRhsReal =
        rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
    Value infinityResultReal =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
    Value infinityResultImag =
        rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);

    // Case 2. Infinite numerator, finite denominator.
    // Each lhs component is replaced by +-1 if it is infinite and +-0
    // otherwise (keeping its sign). The result is inf times the textbook
    // numerator (a*c + b*d) + (b*c - a*d)*i built from those values.
    Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
    Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
    Value rhsFinite =
        rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
    Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
    Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
    Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
    Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
    Value lhsInfinite =
        rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
    Value infNumFiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
    Value one = rewriter.create<arith::ConstantOp>(
        loc, elementType, rewriter.getFloatAttr(elementType, 1));
    Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
        lhsReal);
    Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
        lhsImag);
    Value lhsRealIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
    Value lhsImagIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
    Value resultReal3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
                                       lhsImagIsInfWithSignTimesRhsImag, fmf),
        fmf);
    Value lhsRealIsInfWithSignTimesRhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
    Value lhsImagIsInfWithSignTimesRhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
    Value resultImag3 = rewriter.create<arith::MulFOp>(
        loc, inf,
        rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
                                       lhsRealIsInfWithSignTimesRhsImag, fmf),
        fmf);

    // Case 3: Finite numerator, infinite denominator.
    // Mirror of case 2: each rhs component is replaced by +-1 if it is
    // infinite and +-0 otherwise. The result is 0 times the textbook
    // numerator built from those values.
    Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
    Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
    Value lhsFinite =
        rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
    Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
    Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
    Value rhsInfinite =
        rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
    Value finiteNumInfiniteDenom =
        rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
    Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
        rhsReal);
    Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
        loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
        rhsImag);
    Value rhsRealIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
    Value resultReal4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
                                       rhsImagIsInfWithSignTimesLhsImag, fmf),
        fmf);
    Value rhsRealIsInfWithSignTimesLhsImag =
        rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
    Value rhsImagIsInfWithSignTimesLhsReal =
        rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
    Value resultImag4 = rewriter.create<arith::MulFOp>(
        loc, zero,
        rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
                                       rhsImagIsInfWithSignTimesLhsReal, fmf),
        fmf);

    // Pick the Smith variant, then layer the special cases on top. Priority
    // (highest first): case 1, case 2, case 3, Smith's result.
    Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
    Value resultReal = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
    Value resultImag = rewriter.create<arith::SelectOp>(
        loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
    Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultReal4, resultReal);
    Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
        loc, finiteNumInfiniteDenom, resultImag4, resultImag);
    Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
    Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
        loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
    Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
    Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
        loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);

    // The special-case chain only replaces Smith's result when both of its
    // components are NaN (UNO against the constant 0 is true iff NaN).
    Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultReal, zero);
    Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::UNO, resultImag, zero);
    Value resultIsNaN =
        rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
    Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultRealSpecialCase1, resultReal);
    Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
        loc, resultIsNaN, resultImagSpecialCase1, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(
        op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
    return success();
  }
};
493 
494 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
496 
497  LogicalResult
498  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
499  ConversionPatternRewriter &rewriter) const override {
500  auto loc = op.getLoc();
501  auto type = cast<ComplexType>(adaptor.getComplex().getType());
502  auto elementType = cast<FloatType>(type.getElementType());
503  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
504 
505  Value real =
506  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
507  Value imag =
508  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509  Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
510  Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
511  Value resultReal =
512  rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513  Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
514  Value resultImag =
515  rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
516 
517  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
518  resultImag);
519  return success();
520  }
521 };
522 
523 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
525 
526  LogicalResult
527  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
528  ConversionPatternRewriter &rewriter) const override {
529  auto type = cast<ComplexType>(adaptor.getComplex().getType());
530  auto elementType = cast<FloatType>(type.getElementType());
531  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
532 
533  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
534  Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
535 
536  Value real = b.create<complex::ReOp>(elementType, exp);
537  Value one = b.create<arith::ConstantOp>(elementType,
538  b.getFloatAttr(elementType, 1));
539  Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
540  Value imag = b.create<complex::ImOp>(elementType, exp);
541 
542  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
543  imag);
544  return success();
545  }
546 };
547 
548 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
550 
551  LogicalResult
552  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
553  ConversionPatternRewriter &rewriter) const override {
554  auto type = cast<ComplexType>(adaptor.getComplex().getType());
555  auto elementType = cast<FloatType>(type.getElementType());
556  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
557  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
558 
559  Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
560  fmf.getValue());
561  Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
562  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
563  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
564  Value resultImag =
565  b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
566  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
567  resultImag);
568  return success();
569  }
570 };
571 
572 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
574 
575  LogicalResult
576  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
577  ConversionPatternRewriter &rewriter) const override {
578  auto type = cast<ComplexType>(adaptor.getComplex().getType());
579  auto elementType = cast<FloatType>(type.getElementType());
580  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
581  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
582 
583  Value real = b.create<complex::ReOp>(adaptor.getComplex());
584  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
585 
586  Value half = b.create<arith::ConstantOp>(elementType,
587  b.getFloatAttr(elementType, 0.5));
588  Value one = b.create<arith::ConstantOp>(elementType,
589  b.getFloatAttr(elementType, 1));
590  Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
591  Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
592  Value absImag = b.create<math::AbsFOp>(imag, fmf);
593 
594  Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
595  Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
596 
597  Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
598  realPlusOne, absImag, fmf);
599  Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
600  Value maxAbsOfRealPlusOneAndImagMinusOne =
601  b.create<arith::SelectOp>(useReal, real, maxMinusOne);
602  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
603  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
604  Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
605  Value logOfMaxAbsOfRealPlusOneAndImag =
606  b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
607  Value logOfSqrtPart = b.create<math::Log1pOp>(
608  b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
609  fmfWithNaNInf);
610  Value r = b.create<arith::AddFOp>(
611  b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
612  logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
613  Value resultReal = b.create<arith::SelectOp>(
614  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
615  minAbs, r);
616  Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
617  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
618  resultImag);
619  return success();
620  }
621 };
622 
623 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
625 
626  LogicalResult
627  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
628  ConversionPatternRewriter &rewriter) const override {
629  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
630  auto type = cast<ComplexType>(adaptor.getLhs().getType());
631  auto elementType = cast<FloatType>(type.getElementType());
632  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
633  auto fmfValue = fmf.getValue();
634 
635  Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
636  Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
637  Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
638  Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
639  Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
640  Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
641  Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
642  Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
643 
644  Value lhsRealTimesRhsReal =
645  b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
646  Value lhsRealTimesRhsRealAbs =
647  b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
648  Value lhsImagTimesRhsImag =
649  b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
650  Value lhsImagTimesRhsImagAbs =
651  b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
652  Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
653  lhsImagTimesRhsImag, fmfValue);
654 
655  Value lhsImagTimesRhsReal =
656  b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
657  Value lhsImagTimesRhsRealAbs =
658  b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
659  Value lhsRealTimesRhsImag =
660  b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
661  Value lhsRealTimesRhsImagAbs =
662  b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
663  Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
664  lhsRealTimesRhsImag, fmfValue);
665 
666  // Handle cases where the "naive" calculation results in NaN values.
667  Value realIsNan =
668  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
669  Value imagIsNan =
670  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
671  Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
672 
673  Value inf = b.create<arith::ConstantOp>(
674  elementType,
675  b.getFloatAttr(elementType,
676  APFloat::getInf(elementType.getFloatSemantics())));
677 
678  // Case 1. `lhsReal` or `lhsImag` are infinite.
679  Value lhsRealIsInf =
680  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
681  Value lhsImagIsInf =
682  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
683  Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
684  Value rhsRealIsNan =
685  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
686  Value rhsImagIsNan =
687  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
688  Value zero =
689  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
690  Value one = b.create<arith::ConstantOp>(elementType,
691  b.getFloatAttr(elementType, 1));
692  Value lhsRealIsInfFloat =
693  b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
694  lhsReal = b.create<arith::SelectOp>(
695  lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
696  lhsReal);
697  Value lhsImagIsInfFloat =
698  b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
699  lhsImag = b.create<arith::SelectOp>(
700  lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
701  lhsImag);
702  Value lhsIsInfAndRhsRealIsNan =
703  b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
704  rhsReal = b.create<arith::SelectOp>(
705  lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
706  rhsReal);
707  Value lhsIsInfAndRhsImagIsNan =
708  b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
709  rhsImag = b.create<arith::SelectOp>(
710  lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
711  rhsImag);
712 
713  // Case 2. `rhsReal` or `rhsImag` are infinite.
714  Value rhsRealIsInf =
715  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
716  Value rhsImagIsInf =
717  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
718  Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
719  Value lhsRealIsNan =
720  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
721  Value lhsImagIsNan =
722  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
723  Value rhsRealIsInfFloat =
724  b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
725  rhsReal = b.create<arith::SelectOp>(
726  rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
727  rhsReal);
728  Value rhsImagIsInfFloat =
729  b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
730  rhsImag = b.create<arith::SelectOp>(
731  rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
732  rhsImag);
733  Value rhsIsInfAndLhsRealIsNan =
734  b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
735  lhsReal = b.create<arith::SelectOp>(
736  rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
737  lhsReal);
738  Value rhsIsInfAndLhsImagIsNan =
739  b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
740  lhsImag = b.create<arith::SelectOp>(
741  rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
742  lhsImag);
743  Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
744 
745  // Case 3. One of the pairwise products of left hand side with right hand
746  // side is infinite.
747  Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
748  arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
749  Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
750  arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
751  Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
752  lhsImagTimesRhsImagIsInf);
753  Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
754  arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
755  isSpecialCase =
756  b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
757  Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
758  arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
759  isSpecialCase =
760  b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
761  Type i1Type = b.getI1Type();
762  Value notRecalc = b.create<arith::XOrIOp>(
763  recalc,
764  b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
765  isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
766  Value isSpecialCaseAndLhsRealIsNan =
767  b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
768  lhsReal = b.create<arith::SelectOp>(
769  isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
770  lhsReal);
771  Value isSpecialCaseAndLhsImagIsNan =
772  b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
773  lhsImag = b.create<arith::SelectOp>(
774  isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
775  lhsImag);
776  Value isSpecialCaseAndRhsRealIsNan =
777  b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
778  rhsReal = b.create<arith::SelectOp>(
779  isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
780  rhsReal);
781  Value isSpecialCaseAndRhsImagIsNan =
782  b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
783  rhsImag = b.create<arith::SelectOp>(
784  isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
785  rhsImag);
786  recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
787  recalc = b.create<arith::AndIOp>(isNan, recalc);
788 
789  // Recalculate real part.
790  lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
791  lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
792  Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
793  lhsImagTimesRhsImag, fmfValue);
794  real = b.create<arith::SelectOp>(
795  recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
796 
797  // Recalculate imag part.
798  lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
799  lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
800  Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
801  lhsRealTimesRhsImag, fmfValue);
802  imag = b.create<arith::SelectOp>(
803  recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
804 
805  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
806  return success();
807  }
808 };
809 
810 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
812 
813  LogicalResult
814  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
815  ConversionPatternRewriter &rewriter) const override {
816  auto loc = op.getLoc();
817  auto type = cast<ComplexType>(adaptor.getComplex().getType());
818  auto elementType = cast<FloatType>(type.getElementType());
819 
820  Value real =
821  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
822  Value imag =
823  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
824  Value negReal = rewriter.create<arith::NegFOp>(loc, real);
825  Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
826  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
827  return success();
828  }
829 };
830 
831 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
832  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
833 
834  std::pair<Value, Value> combine(Location loc, Value scaledExp,
835  Value reciprocalExp, Value sin, Value cos,
836  ConversionPatternRewriter &rewriter,
837  arith::FastMathFlagsAttr fmf) const override {
838  // Complex sine is defined as;
839  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
840  // Plugging in:
841  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
842  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
843  // and defining t := exp(y)
844  // We get:
845  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
846  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
847  Value sum =
848  rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
849  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
850  Value diff =
851  rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
852  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
853  return {resultReal, resultImag};
854  }
855 };
856 
857 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
858 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
860 
861  LogicalResult
862  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
863  ConversionPatternRewriter &rewriter) const override {
864  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
865 
866  auto type = cast<ComplexType>(op.getType());
867  auto elementType = cast<FloatType>(type.getElementType());
868  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
869 
870  auto cst = [&](APFloat v) {
871  return b.create<arith::ConstantOp>(elementType,
872  b.getFloatAttr(elementType, v));
873  };
874  const auto &floatSemantics = elementType.getFloatSemantics();
875  Value zero = cst(APFloat::getZero(floatSemantics));
876  Value half = b.create<arith::ConstantOp>(elementType,
877  b.getFloatAttr(elementType, 0.5));
878 
879  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
880  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
881  Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
882  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
883  Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
884  Value cos = b.create<math::CosOp>(sqrtArg, fmf);
885  Value sin = b.create<math::SinOp>(sqrtArg, fmf);
886  // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
887  // 0 * inf.
888  Value sinIsZero =
889  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
890 
891  Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
892  Value resultImag = b.create<arith::SelectOp>(
893  sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
894  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
895  arith::FastMathFlags::ninf)) {
896  Value inf = cst(APFloat::getInf(floatSemantics));
897  Value negInf = cst(APFloat::getInf(floatSemantics, true));
898  Value nan = cst(APFloat::getNaN(floatSemantics));
899  Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
900 
901  Value absImagIsInf =
902  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
903  Value absImagIsNotInf =
904  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
905  Value realIsInf =
906  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
907  Value realIsNegInf =
908  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
909 
910  resultReal = b.create<arith::SelectOp>(
911  b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
912  resultReal);
913  resultReal = b.create<arith::SelectOp>(
914  b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
915 
916  Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
917  resultImag = b.create<arith::SelectOp>(
918  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
919  nan, resultImag);
920  resultImag = b.create<arith::SelectOp>(
921  b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
922  resultImag);
923  }
924 
925  Value resultIsZero =
926  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
927  resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
928  resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
929 
930  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
931  resultImag);
932  return success();
933  }
934 };
935 
936 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
938 
939  LogicalResult
940  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
941  ConversionPatternRewriter &rewriter) const override {
942  auto type = cast<ComplexType>(adaptor.getComplex().getType());
943  auto elementType = cast<FloatType>(type.getElementType());
944  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
945  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
946 
947  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
948  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
949  Value zero =
950  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
951  Value realIsZero =
952  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
953  Value imagIsZero =
954  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
955  Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
956  auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
957  Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
958  Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
959  Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
960  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
961  adaptor.getComplex(), sign);
962  return success();
963  }
964 };
965 
966 template <typename Op>
967 struct TanTanhOpConversion : public OpConversionPattern<Op> {
969 
970  LogicalResult
971  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
972  ConversionPatternRewriter &rewriter) const override {
973  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
974  auto loc = op.getLoc();
975  auto type = cast<ComplexType>(adaptor.getComplex().getType());
976  auto elementType = cast<FloatType>(type.getElementType());
977  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
978  const auto &floatSemantics = elementType.getFloatSemantics();
979 
980  Value real =
981  b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
982  Value imag =
983  b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
984  Value negOne = b.create<arith::ConstantOp>(
985  elementType, b.getFloatAttr(elementType, -1.0));
986 
987  if constexpr (std::is_same_v<Op, complex::TanOp>) {
988  // tan(x+yi) = -i*tanh(-y + xi)
989  std::swap(real, imag);
990  real = b.create<arith::MulFOp>(real, negOne, fmf);
991  }
992 
993  auto cst = [&](APFloat v) {
994  return b.create<arith::ConstantOp>(elementType,
995  b.getFloatAttr(elementType, v));
996  };
997  Value inf = cst(APFloat::getInf(floatSemantics));
998  Value four = b.create<arith::ConstantOp>(elementType,
999  b.getFloatAttr(elementType, 4.0));
1000  Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1001  Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1002 
1003  Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1004  Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1005  Value realNum =
1006  b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1007 
1008  Value cosImag = b.create<math::CosOp>(imag, fmf);
1009  Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1010  Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1011  Value sinImag = b.create<math::SinOp>(imag, fmf);
1012 
1013  Value imagNum = b.create<arith::MulFOp>(
1014  four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1015 
1016  Value expSumMinusTwo =
1017  b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1018  Value denom =
1019  b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1020 
1021  Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1022  expSumMinusTwo, inf, fmf);
1023  Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1024 
1025  Value resultReal = b.create<arith::SelectOp>(
1026  isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1027  Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1028 
1029  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1030  arith::FastMathFlags::ninf)) {
1031  Value absReal = b.create<math::AbsFOp>(real, fmf);
1032  Value zero = b.create<arith::ConstantOp>(
1033  elementType, b.getFloatAttr(elementType, 0.0));
1034  Value nan = cst(APFloat::getNaN(floatSemantics));
1035 
1036  Value absRealIsInf =
1037  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1038  Value imagIsZero =
1039  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1040  Value absRealIsNotInf = b.create<arith::XOrIOp>(
1041  absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1042 
1043  Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1044  imagNum, imagNum, fmf);
1045  Value resultRealIsNaN =
1046  b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1047  Value resultImagIsZero = b.create<arith::OrIOp>(
1048  imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1049 
1050  resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1051  resultImag =
1052  b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1053  }
1054 
1055  if constexpr (std::is_same_v<Op, complex::TanOp>) {
1056  // tan(x+yi) = -i*tanh(-y + xi)
1057  std::swap(resultReal, resultImag);
1058  resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
1059  }
1060 
1061  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1062  resultImag);
1063  return success();
1064  }
1065 };
1066 
1067 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
1069 
1070  LogicalResult
1071  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1072  ConversionPatternRewriter &rewriter) const override {
1073  auto loc = op.getLoc();
1074  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1075  auto elementType = cast<FloatType>(type.getElementType());
1076  Value real =
1077  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1078  Value imag =
1079  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1080  Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
1081 
1082  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
1083 
1084  return success();
1085  }
1086 };
1087 
1088 /// Converts lhs^y = (a+bi)^(c+di) to
1089 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1090 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
1091 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
1092  ComplexType type, Value lhs, Value c, Value d,
1093  arith::FastMathFlags fmf) {
1094  auto elementType = cast<FloatType>(type.getElementType());
1095 
1096  Value a = builder.create<complex::ReOp>(lhs);
1097  Value b = builder.create<complex::ImOp>(lhs);
1098 
1099  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
1100  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
1101 
1102  Value negD = builder.create<arith::NegFOp>(d, fmf);
1103  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
1104  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
1105  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
1106 
1107  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1108  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
1109  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
1110  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
1111  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1112  Value cosQ = builder.create<math::CosOp>(q, fmf);
1113  Value sinQ = builder.create<math::SinOp>(q, fmf);
1114 
1115  Value inf = builder.create<arith::ConstantOp>(
1116  elementType,
1117  builder.getFloatAttr(elementType,
1118  APFloat::getInf(elementType.getFloatSemantics())));
1119  Value zero = builder.create<arith::ConstantOp>(
1120  elementType, builder.getFloatAttr(elementType, 0.0));
1121  Value one = builder.create<arith::ConstantOp>(
1122  elementType, builder.getFloatAttr(elementType, 1.0));
1123  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
1124  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
1125  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
1126 
1127  // Case 0:
1128  // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
1129  // Branch Cuts for Complex Elementary Functions or Much Ado About
1130  // Nothing's Sign Bit, W. Kahan, Section 10.
1131  Value absEqZero =
1132  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
1133  Value dEqZero =
1134  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1135  Value cEqZero =
1136  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1137  Value bEqZero =
1138  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1139 
1140  Value zeroLeC =
1141  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1142  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
1143  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
1144  Value complexOneOrZero =
1145  builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
1146  Value coeffCosSin =
1147  builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1148  Value cutoff0 = builder.create<arith::SelectOp>(
1149  builder.create<arith::AndIOp>(
1150  builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1151  complexOneOrZero, coeffCosSin);
1152 
1153  // Case 1:
1154  // x^0 is defined to be 1 for any x, see
1155  // Branch Cuts for Complex Elementary Functions or Much Ado About
1156  // Nothing's Sign Bit, W. Kahan, Section 10.
1157  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
1158  Value cutoff1 =
1159  builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1160 
1161  // Case 2:
1162  // 1^(c + d*i) = 1 + 0*i
1163  Value lhsEqOne = builder.create<arith::AndIOp>(
1164  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1165  bEqZero);
1166  Value cutoff2 =
1167  builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1168 
1169  // Case 3:
1170  // inf^(c + 0*i) = inf + 0*i, c > 0
1171  Value lhsEqInf = builder.create<arith::AndIOp>(
1172  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1173  bEqZero);
1174  Value rhsGt0 = builder.create<arith::AndIOp>(
1175  dEqZero,
1176  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1177  Value cutoff3 = builder.create<arith::SelectOp>(
1178  builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1179 
1180  // Case 4:
1181  // inf^(c + 0*i) = 0 + 0*i, c < 0
1182  Value rhsLt0 = builder.create<arith::AndIOp>(
1183  dEqZero,
1184  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1185  Value cutoff4 = builder.create<arith::SelectOp>(
1186  builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1187 
1188  return cutoff4;
1189 }
1190 
1191 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1193 
1194  LogicalResult
1195  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1196  ConversionPatternRewriter &rewriter) const override {
1197  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1198  auto type = cast<ComplexType>(adaptor.getLhs().getType());
1199  auto elementType = cast<FloatType>(type.getElementType());
1200 
1201  Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1202  Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1203 
1204  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1205  c, d, op.getFastmath())});
1206  return success();
1207  }
1208 };
1209 
1210 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1212 
1213  LogicalResult
1214  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1215  ConversionPatternRewriter &rewriter) const override {
1216  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1217  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1218  auto elementType = cast<FloatType>(type.getElementType());
1219 
1220  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1221 
1222  auto cst = [&](APFloat v) {
1223  return b.create<arith::ConstantOp>(elementType,
1224  b.getFloatAttr(elementType, v));
1225  };
1226  const auto &floatSemantics = elementType.getFloatSemantics();
1227  Value zero = cst(APFloat::getZero(floatSemantics));
1228  Value inf = cst(APFloat::getInf(floatSemantics));
1229  Value negHalf = b.create<arith::ConstantOp>(
1230  elementType, b.getFloatAttr(elementType, -0.5));
1231  Value nan = cst(APFloat::getNaN(floatSemantics));
1232 
1233  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1234  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1235  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1236  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1237  Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1238  Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1239  Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1240 
1241  Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1242  Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1243 
1244  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1245  arith::FastMathFlags::ninf)) {
1246  Value negOne = b.create<arith::ConstantOp>(
1247  elementType, b.getFloatAttr(elementType, -1));
1248 
1249  Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1250  Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1251  Value negImagSignedZero =
1252  b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1253 
1254  Value absReal = b.create<math::AbsFOp>(real, fmf);
1255  Value absImag = b.create<math::AbsFOp>(imag, fmf);
1256 
1257  Value absImagIsInf =
1258  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1259  Value realIsNan =
1260  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1261  Value realIsInf =
1262  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1263  Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1264 
1265  Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1266 
1267  resultReal =
1268  b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1269  resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1270  resultImag);
1271  }
1272 
1273  Value isRealZero =
1274  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1275  Value isImagZero =
1276  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1277  Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1278 
1279  resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1280  resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1281 
1282  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1283  resultImag);
1284  return success();
1285  }
1286 };
1287 
1288 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1290 
1291  LogicalResult
1292  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1293  ConversionPatternRewriter &rewriter) const override {
1294  auto loc = op.getLoc();
1295  auto type = op.getType();
1296  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1297 
1298  Value real =
1299  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1300  Value imag =
1301  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1302 
1303  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1304 
1305  return success();
1306  }
1307 };
1308 
1309 } // namespace
1310 
1312  RewritePatternSet &patterns) {
1313  // clang-format off
1314  patterns.add<
1315  AbsOpConversion,
1316  AngleOpConversion,
1317  Atan2OpConversion,
1318  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1319  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1320  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1321  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1322  ConjOpConversion,
1323  CosOpConversion,
1324  DivOpConversion,
1325  ExpOpConversion,
1326  Expm1OpConversion,
1327  Log1pOpConversion,
1328  LogOpConversion,
1329  MulOpConversion,
1330  NegOpConversion,
1331  SignOpConversion,
1332  SinOpConversion,
1333  SqrtOpConversion,
1334  TanTanhOpConversion<complex::TanOp>,
1335  TanTanhOpConversion<complex::TanhOp>,
1336  PowOpConversion,
1337  RsqrtOpConversion
1338  >(patterns.getContext());
1339  // clang-format on
1340 }
1341 
1342 namespace {
1343 struct ConvertComplexToStandardPass
1344  : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1345  void runOnOperation() override;
1346 };
1347 
1348 void ConvertComplexToStandardPass::runOnOperation() {
1349  // Convert to the Standard dialect using the converter defined above.
1350  RewritePatternSet patterns(&getContext());
1352 
1353  ConversionTarget target(getContext());
1354  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1355  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1356  if (failed(
1357  applyPartialConversion(getOperation(), target, std::move(patterns))))
1358  signalPassFailure();
1359 }
1360 } // namespace
1361 
1363  return std::make_unique<ConvertComplexToStandardPass>();
1364 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:254
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:277
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:347
IntegerType getI1Type()
Definition: Builders.cpp:89
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:480
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
This provides public APIs that all operations should have.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.