MLIR  20.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 #include <type_traits>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
/// Selects which quantity computeAbs produces for z = real + i*imag.
enum class AbsFn {
  abs,   ///< |z|
  sqrt,  ///< sqrt(|z|)
  rsqrt, ///< 1 / sqrt(|z|)
};
31 
32 // Returns the absolute value, its square root or its reciprocal square root.
33 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35  Value one = b.create<arith::ConstantOp>(real.getType(),
36  b.getFloatAttr(real.getType(), 1.0));
37 
38  Value absReal = b.create<math::AbsFOp>(real, fmf);
39  Value absImag = b.create<math::AbsFOp>(imag, fmf);
40 
41  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
42  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
43 
44  // The lowering below requires NaNs and infinities to work correctly.
45  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
46  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
47  Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
48  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
49  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
50  Value result;
51 
52  if (fn == AbsFn::rsqrt) {
53  ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
54  min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
55  max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
56  }
57 
58  if (fn == AbsFn::sqrt) {
59  Value quarter = b.create<arith::ConstantOp>(
60  real.getType(), b.getFloatAttr(real.getType(), 0.25));
61  // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
62  Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
63  Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
64  result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
65  } else {
66  Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
67  result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
68  }
69 
70  Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
71  result, fmfWithNaNInf);
72  return b.create<arith::SelectOp>(isNaN, min, result);
73 }
74 
75 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
77 
78  LogicalResult
79  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
80  ConversionPatternRewriter &rewriter) const override {
81  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
82 
83  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
84 
85  Value real = b.create<complex::ReOp>(adaptor.getComplex());
86  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
87  rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
88 
89  return success();
90  }
91 };
92 
93 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
94 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
96 
97  LogicalResult
98  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
99  ConversionPatternRewriter &rewriter) const override {
100  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
101 
102  auto type = cast<ComplexType>(op.getType());
103  Type elementType = type.getElementType();
104  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
105 
106  Value lhs = adaptor.getLhs();
107  Value rhs = adaptor.getRhs();
108 
109  Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
110  Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
111  Value rhsSquaredPlusLhsSquared =
112  b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
113  Value sqrtOfRhsSquaredPlusLhsSquared =
114  b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
115 
116  Value zero =
117  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
118  Value one = b.create<arith::ConstantOp>(elementType,
119  b.getFloatAttr(elementType, 1));
120  Value i = b.create<complex::CreateOp>(type, zero, one);
121  Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
122  Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
123 
124  Value divResult = b.create<complex::DivOp>(
125  rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
126  Value logResult = b.create<complex::LogOp>(divResult, fmf);
127 
128  Value negativeOne = b.create<arith::ConstantOp>(
129  elementType, b.getFloatAttr(elementType, -1));
130  Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
131 
132  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
133  return success();
134  }
135 };
136 
137 template <typename ComparisonOp, arith::CmpFPredicate p>
138 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
140  using ResultCombiner =
141  std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
142  arith::AndIOp, arith::OrIOp>;
143 
144  LogicalResult
145  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
146  ConversionPatternRewriter &rewriter) const override {
147  auto loc = op.getLoc();
148  auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
149 
150  Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
151  Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
152  Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
153  Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
154  Value realComparison =
155  rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
156  Value imagComparison =
157  rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
158 
159  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
160  imagComparison);
161  return success();
162  }
163 };
164 
165 // Default conversion which applies the BinaryStandardOp separately on the real
166 // and imaginary parts. Can for example be used for complex::AddOp and
167 // complex::SubOp.
168 template <typename BinaryComplexOp, typename BinaryStandardOp>
169 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
171 
172  LogicalResult
173  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
174  ConversionPatternRewriter &rewriter) const override {
175  auto type = cast<ComplexType>(adaptor.getLhs().getType());
176  auto elementType = cast<FloatType>(type.getElementType());
177  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
178  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
179 
180  Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
181  Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
182  Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
183  fmf.getValue());
184  Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
185  Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
186  Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
187  fmf.getValue());
188  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
189  resultImag);
190  return success();
191  }
192 };
193 
194 template <typename TrigonometricOp>
195 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
196  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
197 
199 
200  LogicalResult
201  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  auto loc = op.getLoc();
204  auto type = cast<ComplexType>(adaptor.getComplex().getType());
205  auto elementType = cast<FloatType>(type.getElementType());
206  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
207 
208  Value real =
209  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
210  Value imag =
211  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
212 
213  // Trigonometric ops use a set of common building blocks to convert to real
214  // ops. Here we create these building blocks and call into an op-specific
215  // implementation in the subclass to combine them.
216  Value half = rewriter.create<arith::ConstantOp>(
217  loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
218  Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
219  Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
220  Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
221  Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
222  Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
223 
224  auto resultPair =
225  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
226 
227  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
228  resultPair.second);
229  return success();
230  }
231 
232  virtual std::pair<Value, Value>
233  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
234  Value cos, ConversionPatternRewriter &rewriter,
235  arith::FastMathFlagsAttr fmf) const = 0;
236 };
237 
238 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
239  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
240 
241  std::pair<Value, Value> combine(Location loc, Value scaledExp,
242  Value reciprocalExp, Value sin, Value cos,
243  ConversionPatternRewriter &rewriter,
244  arith::FastMathFlagsAttr fmf) const override {
245  // Complex cosine is defined as;
246  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
247  // Plugging in:
248  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
249  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
250  // and defining t := exp(y)
251  // We get:
252  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
253  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
254  Value sum =
255  rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
256  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
257  Value diff =
258  rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
259  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
260  return {resultReal, resultImag};
261  }
262 };
263 
264 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
266 
267  LogicalResult
268  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
269  ConversionPatternRewriter &rewriter) const override {
270  auto loc = op.getLoc();
271  auto type = cast<ComplexType>(adaptor.getLhs().getType());
272  auto elementType = cast<FloatType>(type.getElementType());
273  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
274 
275  Value lhsReal =
276  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
277  Value lhsImag =
278  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
279  Value rhsReal =
280  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
281  Value rhsImag =
282  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
283 
284  // Smith's algorithm to divide complex numbers. It is just a bit smarter
285  // way to compute the following formula:
286  // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
287  // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
288  // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
289  // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
290  // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
291  //
292  // Depending on whether |rhsReal| < |rhsImag| we compute either
293  // rhsRealImagRatio = rhsReal / rhsImag
294  // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
295  // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
296  // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
297  //
298  // or
299  //
300  // rhsImagRealRatio = rhsImag / rhsReal
301  // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
302  // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
303  // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
304  //
305  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
306  Value rhsRealImagRatio =
307  rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
308  Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
309  loc, rhsImag,
310  rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
311  fmf);
312  Value realNumerator1 = rewriter.create<arith::AddFOp>(
313  loc,
314  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
315  lhsImag, fmf);
316  Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
317  rhsRealImagDenom, fmf);
318  Value imagNumerator1 = rewriter.create<arith::SubFOp>(
319  loc,
320  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
321  lhsReal, fmf);
322  Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
323  rhsRealImagDenom, fmf);
324 
325  Value rhsImagRealRatio =
326  rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
327  Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
328  loc, rhsReal,
329  rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
330  fmf);
331  Value realNumerator2 = rewriter.create<arith::AddFOp>(
332  loc, lhsReal,
333  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
334  fmf);
335  Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
336  rhsImagRealDenom, fmf);
337  Value imagNumerator2 = rewriter.create<arith::SubFOp>(
338  loc, lhsImag,
339  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
340  fmf);
341  Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
342  rhsImagRealDenom, fmf);
343 
344  // Consider corner cases.
345  // Case 1. Zero denominator, numerator contains at most one NaN value.
346  Value zero = rewriter.create<arith::ConstantOp>(
347  loc, elementType, rewriter.getZeroAttr(elementType));
348  Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
349  Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
350  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
351  Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
352  Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
353  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
354  Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
355  loc, arith::CmpFPredicate::ORD, lhsReal, zero);
356  Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
357  loc, arith::CmpFPredicate::ORD, lhsImag, zero);
358  Value lhsContainsNotNaNValue =
359  rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
360  Value resultIsInfinity = rewriter.create<arith::AndIOp>(
361  loc, lhsContainsNotNaNValue,
362  rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
363  Value inf = rewriter.create<arith::ConstantOp>(
364  loc, elementType,
365  rewriter.getFloatAttr(
366  elementType, APFloat::getInf(elementType.getFloatSemantics())));
367  Value infWithSignOfRhsReal =
368  rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
369  Value infinityResultReal =
370  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
371  Value infinityResultImag =
372  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
373 
374  // Case 2. Infinite numerator, finite denominator.
375  Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
376  loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
377  Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
378  loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
379  Value rhsFinite =
380  rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
381  Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
382  Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
383  loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
384  Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
385  Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
386  loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
387  Value lhsInfinite =
388  rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
389  Value infNumFiniteDenom =
390  rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
391  Value one = rewriter.create<arith::ConstantOp>(
392  loc, elementType, rewriter.getFloatAttr(elementType, 1));
393  Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
394  loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
395  lhsReal);
396  Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
397  loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
398  lhsImag);
399  Value lhsRealIsInfWithSignTimesRhsReal =
400  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
401  Value lhsImagIsInfWithSignTimesRhsImag =
402  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
403  Value resultReal3 = rewriter.create<arith::MulFOp>(
404  loc, inf,
405  rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
406  lhsImagIsInfWithSignTimesRhsImag, fmf),
407  fmf);
408  Value lhsRealIsInfWithSignTimesRhsImag =
409  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
410  Value lhsImagIsInfWithSignTimesRhsReal =
411  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
412  Value resultImag3 = rewriter.create<arith::MulFOp>(
413  loc, inf,
414  rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
415  lhsRealIsInfWithSignTimesRhsImag, fmf),
416  fmf);
417 
418  // Case 3: Finite numerator, infinite denominator.
419  Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
420  loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
421  Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
422  loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
423  Value lhsFinite =
424  rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
425  Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
426  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
427  Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
428  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
429  Value rhsInfinite =
430  rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
431  Value finiteNumInfiniteDenom =
432  rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
433  Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
434  loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
435  rhsReal);
436  Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
437  loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
438  rhsImag);
439  Value rhsRealIsInfWithSignTimesLhsReal =
440  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
441  Value rhsImagIsInfWithSignTimesLhsImag =
442  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
443  Value resultReal4 = rewriter.create<arith::MulFOp>(
444  loc, zero,
445  rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
446  rhsImagIsInfWithSignTimesLhsImag, fmf),
447  fmf);
448  Value rhsRealIsInfWithSignTimesLhsImag =
449  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
450  Value rhsImagIsInfWithSignTimesLhsReal =
451  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
452  Value resultImag4 = rewriter.create<arith::MulFOp>(
453  loc, zero,
454  rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
455  rhsImagIsInfWithSignTimesLhsReal, fmf),
456  fmf);
457 
458  Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
459  loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
460  Value resultReal = rewriter.create<arith::SelectOp>(
461  loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
462  Value resultImag = rewriter.create<arith::SelectOp>(
463  loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
464  Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
465  loc, finiteNumInfiniteDenom, resultReal4, resultReal);
466  Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
467  loc, finiteNumInfiniteDenom, resultImag4, resultImag);
468  Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
469  loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
470  Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
471  loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
472  Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
473  loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
474  Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
475  loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
476 
477  Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
478  loc, arith::CmpFPredicate::UNO, resultReal, zero);
479  Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
480  loc, arith::CmpFPredicate::UNO, resultImag, zero);
481  Value resultIsNaN =
482  rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
483  Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
484  loc, resultIsNaN, resultRealSpecialCase1, resultReal);
485  Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
486  loc, resultIsNaN, resultImagSpecialCase1, resultImag);
487 
488  rewriter.replaceOpWithNewOp<complex::CreateOp>(
489  op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
490  return success();
491  }
492 };
493 
494 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
496 
497  LogicalResult
498  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
499  ConversionPatternRewriter &rewriter) const override {
500  auto loc = op.getLoc();
501  auto type = cast<ComplexType>(adaptor.getComplex().getType());
502  auto elementType = cast<FloatType>(type.getElementType());
503  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
504 
505  Value real =
506  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
507  Value imag =
508  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
509  Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
510  Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
511  Value resultReal =
512  rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
513  Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
514  Value resultImag =
515  rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
516 
517  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
518  resultImag);
519  return success();
520  }
521 };
522 
523 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
524  ArrayRef<double> coefficients,
525  arith::FastMathFlagsAttr fmf) {
526  auto argType = mlir::cast<FloatType>(arg.getType());
527  Value poly =
528  b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
529  for (unsigned i = 1; i < coefficients.size(); ++i) {
530  poly = b.create<math::FmaOp>(
531  poly, arg,
532  b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
533  fmf);
534  }
535  return poly;
536 }
537 
538 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
540 
541  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
542  // [handle inaccuracies when a and/or b are small]
543  // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
544  // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
545  LogicalResult
546  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
547  ConversionPatternRewriter &rewriter) const override {
548  auto type = op.getType();
549  auto elemType = mlir::cast<FloatType>(type.getElementType());
550 
551  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
552  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
553  Value real = b.create<complex::ReOp>(adaptor.getComplex());
554  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
555 
556  Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
557  Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
558 
559  Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
560  Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
561 
562  Value sinImag = b.create<math::SinOp>(imag, fmf);
563  Value cosm1Imag = emitCosm1(imag, fmf, b);
564  Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
565 
566  Value realResult = b.create<arith::AddFOp>(
567  b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
568 
569  Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
570  zero, fmf.getValue());
571  Value imagResult = b.create<arith::SelectOp>(
572  imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
573 
574  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
575  imagResult);
576  return success();
577  }
578 
579 private:
580  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
581  ImplicitLocOpBuilder &b) const {
582  auto argType = mlir::cast<FloatType>(arg.getType());
583  auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
584  auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
585 
586  // Algorithm copied from cephes cosm1.
587  SmallVector<double, 7> kCoeffs{
588  4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
589  2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
590  2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
591  4.1666666666666666609054E-2,
592  };
593  Value cos = b.create<math::CosOp>(arg, fmf);
594  Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
595 
596  Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
597  Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
598  Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
599 
600  auto forSmallArg =
601  b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
602  b.create<arith::MulFOp>(negHalf, argPow2, fmf));
603 
604  // (pi/4)^2 is approximately 0.61685
605  Value piOver4Pow2 =
606  b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
607  Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
608  piOver4Pow2, fmf.getValue());
609  return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
610  }
611 };
612 
613 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
615 
616  LogicalResult
617  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
618  ConversionPatternRewriter &rewriter) const override {
619  auto type = cast<ComplexType>(adaptor.getComplex().getType());
620  auto elementType = cast<FloatType>(type.getElementType());
621  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
622  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
623 
624  Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
625  fmf.getValue());
626  Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
627  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
628  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
629  Value resultImag =
630  b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
631  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
632  resultImag);
633  return success();
634  }
635 };
636 
637 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
639 
640  LogicalResult
641  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
642  ConversionPatternRewriter &rewriter) const override {
643  auto type = cast<ComplexType>(adaptor.getComplex().getType());
644  auto elementType = cast<FloatType>(type.getElementType());
645  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
646  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
647 
648  Value real = b.create<complex::ReOp>(adaptor.getComplex());
649  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
650 
651  Value half = b.create<arith::ConstantOp>(elementType,
652  b.getFloatAttr(elementType, 0.5));
653  Value one = b.create<arith::ConstantOp>(elementType,
654  b.getFloatAttr(elementType, 1));
655  Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
656  Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
657  Value absImag = b.create<math::AbsFOp>(imag, fmf);
658 
659  Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
660  Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
661 
662  Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
663  realPlusOne, absImag, fmf);
664  Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
665  Value maxAbsOfRealPlusOneAndImagMinusOne =
666  b.create<arith::SelectOp>(useReal, real, maxMinusOne);
667  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
668  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
669  Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
670  Value logOfMaxAbsOfRealPlusOneAndImag =
671  b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
672  Value logOfSqrtPart = b.create<math::Log1pOp>(
673  b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
674  fmfWithNaNInf);
675  Value r = b.create<arith::AddFOp>(
676  b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
677  logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
678  Value resultReal = b.create<arith::SelectOp>(
679  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
680  minAbs, r);
681  Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
682  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
683  resultImag);
684  return success();
685  }
686 };
687 
688 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
690 
691  LogicalResult
692  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
693  ConversionPatternRewriter &rewriter) const override {
694  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
695  auto type = cast<ComplexType>(adaptor.getLhs().getType());
696  auto elementType = cast<FloatType>(type.getElementType());
697  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
698  auto fmfValue = fmf.getValue();
699 
700  Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
701  Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
702  Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
703  Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
704  Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
705  Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
706  Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
707  Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
708 
709  Value lhsRealTimesRhsReal =
710  b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
711  Value lhsRealTimesRhsRealAbs =
712  b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
713  Value lhsImagTimesRhsImag =
714  b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
715  Value lhsImagTimesRhsImagAbs =
716  b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
717  Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
718  lhsImagTimesRhsImag, fmfValue);
719 
720  Value lhsImagTimesRhsReal =
721  b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
722  Value lhsImagTimesRhsRealAbs =
723  b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
724  Value lhsRealTimesRhsImag =
725  b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
726  Value lhsRealTimesRhsImagAbs =
727  b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
728  Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
729  lhsRealTimesRhsImag, fmfValue);
730 
731  // Handle cases where the "naive" calculation results in NaN values.
732  Value realIsNan =
733  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
734  Value imagIsNan =
735  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
736  Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
737 
738  Value inf = b.create<arith::ConstantOp>(
739  elementType,
740  b.getFloatAttr(elementType,
741  APFloat::getInf(elementType.getFloatSemantics())));
742 
743  // Case 1. `lhsReal` or `lhsImag` are infinite.
744  Value lhsRealIsInf =
745  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
746  Value lhsImagIsInf =
747  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
748  Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
749  Value rhsRealIsNan =
750  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
751  Value rhsImagIsNan =
752  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
753  Value zero =
754  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
755  Value one = b.create<arith::ConstantOp>(elementType,
756  b.getFloatAttr(elementType, 1));
757  Value lhsRealIsInfFloat =
758  b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
759  lhsReal = b.create<arith::SelectOp>(
760  lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
761  lhsReal);
762  Value lhsImagIsInfFloat =
763  b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
764  lhsImag = b.create<arith::SelectOp>(
765  lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
766  lhsImag);
767  Value lhsIsInfAndRhsRealIsNan =
768  b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
769  rhsReal = b.create<arith::SelectOp>(
770  lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
771  rhsReal);
772  Value lhsIsInfAndRhsImagIsNan =
773  b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
774  rhsImag = b.create<arith::SelectOp>(
775  lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
776  rhsImag);
777 
778  // Case 2. `rhsReal` or `rhsImag` are infinite.
779  Value rhsRealIsInf =
780  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
781  Value rhsImagIsInf =
782  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
783  Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
784  Value lhsRealIsNan =
785  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
786  Value lhsImagIsNan =
787  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
788  Value rhsRealIsInfFloat =
789  b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
790  rhsReal = b.create<arith::SelectOp>(
791  rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
792  rhsReal);
793  Value rhsImagIsInfFloat =
794  b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
795  rhsImag = b.create<arith::SelectOp>(
796  rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
797  rhsImag);
798  Value rhsIsInfAndLhsRealIsNan =
799  b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
800  lhsReal = b.create<arith::SelectOp>(
801  rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
802  lhsReal);
803  Value rhsIsInfAndLhsImagIsNan =
804  b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
805  lhsImag = b.create<arith::SelectOp>(
806  rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
807  lhsImag);
808  Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
809 
810  // Case 3. One of the pairwise products of left hand side with right hand
811  // side is infinite.
812  Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
813  arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
814  Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
815  arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
816  Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
817  lhsImagTimesRhsImagIsInf);
818  Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
819  arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
820  isSpecialCase =
821  b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
822  Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
823  arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
824  isSpecialCase =
825  b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
826  Type i1Type = b.getI1Type();
827  Value notRecalc = b.create<arith::XOrIOp>(
828  recalc,
829  b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
830  isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
831  Value isSpecialCaseAndLhsRealIsNan =
832  b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
833  lhsReal = b.create<arith::SelectOp>(
834  isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
835  lhsReal);
836  Value isSpecialCaseAndLhsImagIsNan =
837  b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
838  lhsImag = b.create<arith::SelectOp>(
839  isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
840  lhsImag);
841  Value isSpecialCaseAndRhsRealIsNan =
842  b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
843  rhsReal = b.create<arith::SelectOp>(
844  isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
845  rhsReal);
846  Value isSpecialCaseAndRhsImagIsNan =
847  b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
848  rhsImag = b.create<arith::SelectOp>(
849  isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
850  rhsImag);
851  recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
852  recalc = b.create<arith::AndIOp>(isNan, recalc);
853 
854  // Recalculate real part.
855  lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
856  lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
857  Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
858  lhsImagTimesRhsImag, fmfValue);
859  real = b.create<arith::SelectOp>(
860  recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
861 
862  // Recalculate imag part.
863  lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
864  lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
865  Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
866  lhsRealTimesRhsImag, fmfValue);
867  imag = b.create<arith::SelectOp>(
868  recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
869 
870  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
871  return success();
872  }
873 };
874 
875 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
877 
878  LogicalResult
879  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
880  ConversionPatternRewriter &rewriter) const override {
881  auto loc = op.getLoc();
882  auto type = cast<ComplexType>(adaptor.getComplex().getType());
883  auto elementType = cast<FloatType>(type.getElementType());
884 
885  Value real =
886  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
887  Value imag =
888  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
889  Value negReal = rewriter.create<arith::NegFOp>(loc, real);
890  Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
891  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
892  return success();
893  }
894 };
895 
896 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
897  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
898 
899  std::pair<Value, Value> combine(Location loc, Value scaledExp,
900  Value reciprocalExp, Value sin, Value cos,
901  ConversionPatternRewriter &rewriter,
902  arith::FastMathFlagsAttr fmf) const override {
903  // Complex sine is defined as;
904  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
905  // Plugging in:
906  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
907  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
908  // and defining t := exp(y)
909  // We get:
910  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
911  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
912  Value sum =
913  rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
914  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
915  Value diff =
916  rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
917  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
918  return {resultReal, resultImag};
919  }
920 };
921 
922 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
923 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
925 
926  LogicalResult
927  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
928  ConversionPatternRewriter &rewriter) const override {
929  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
930 
931  auto type = cast<ComplexType>(op.getType());
932  auto elementType = cast<FloatType>(type.getElementType());
933  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
934 
935  auto cst = [&](APFloat v) {
936  return b.create<arith::ConstantOp>(elementType,
937  b.getFloatAttr(elementType, v));
938  };
939  const auto &floatSemantics = elementType.getFloatSemantics();
940  Value zero = cst(APFloat::getZero(floatSemantics));
941  Value half = b.create<arith::ConstantOp>(elementType,
942  b.getFloatAttr(elementType, 0.5));
943 
944  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
945  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
946  Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
947  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
948  Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
949  Value cos = b.create<math::CosOp>(sqrtArg, fmf);
950  Value sin = b.create<math::SinOp>(sqrtArg, fmf);
951  // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
952  // 0 * inf.
953  Value sinIsZero =
954  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
955 
956  Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
957  Value resultImag = b.create<arith::SelectOp>(
958  sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
959  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
960  arith::FastMathFlags::ninf)) {
961  Value inf = cst(APFloat::getInf(floatSemantics));
962  Value negInf = cst(APFloat::getInf(floatSemantics, true));
963  Value nan = cst(APFloat::getNaN(floatSemantics));
964  Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
965 
966  Value absImagIsInf =
967  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
968  Value absImagIsNotInf =
969  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
970  Value realIsInf =
971  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
972  Value realIsNegInf =
973  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
974 
975  resultReal = b.create<arith::SelectOp>(
976  b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
977  resultReal);
978  resultReal = b.create<arith::SelectOp>(
979  b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
980 
981  Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
982  resultImag = b.create<arith::SelectOp>(
983  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
984  nan, resultImag);
985  resultImag = b.create<arith::SelectOp>(
986  b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
987  resultImag);
988  }
989 
990  Value resultIsZero =
991  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
992  resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
993  resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
994 
995  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
996  resultImag);
997  return success();
998  }
999 };
1000 
1001 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
1003 
1004  LogicalResult
1005  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
1006  ConversionPatternRewriter &rewriter) const override {
1007  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008  auto elementType = cast<FloatType>(type.getElementType());
1009  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1010  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1011 
1012  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1013  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1014  Value zero =
1015  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
1016  Value realIsZero =
1017  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
1018  Value imagIsZero =
1019  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
1020  Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
1021  auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
1022  Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
1023  Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
1024  Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
1025  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
1026  adaptor.getComplex(), sign);
1027  return success();
1028  }
1029 };
1030 
1031 template <typename Op>
1032 struct TanTanhOpConversion : public OpConversionPattern<Op> {
1034 
1035  LogicalResult
1036  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
1037  ConversionPatternRewriter &rewriter) const override {
1038  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1039  auto loc = op.getLoc();
1040  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1041  auto elementType = cast<FloatType>(type.getElementType());
1042  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1043  const auto &floatSemantics = elementType.getFloatSemantics();
1044 
1045  Value real =
1046  b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1047  Value imag =
1048  b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1049  Value negOne = b.create<arith::ConstantOp>(
1050  elementType, b.getFloatAttr(elementType, -1.0));
1051 
1052  if constexpr (std::is_same_v<Op, complex::TanOp>) {
1053  // tan(x+yi) = -i*tanh(-y + xi)
1054  std::swap(real, imag);
1055  real = b.create<arith::MulFOp>(real, negOne, fmf);
1056  }
1057 
1058  auto cst = [&](APFloat v) {
1059  return b.create<arith::ConstantOp>(elementType,
1060  b.getFloatAttr(elementType, v));
1061  };
1062  Value inf = cst(APFloat::getInf(floatSemantics));
1063  Value four = b.create<arith::ConstantOp>(elementType,
1064  b.getFloatAttr(elementType, 4.0));
1065  Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1066  Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1067 
1068  Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1069  Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1070  Value realNum =
1071  b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1072 
1073  Value cosImag = b.create<math::CosOp>(imag, fmf);
1074  Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1075  Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1076  Value sinImag = b.create<math::SinOp>(imag, fmf);
1077 
1078  Value imagNum = b.create<arith::MulFOp>(
1079  four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1080 
1081  Value expSumMinusTwo =
1082  b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1083  Value denom =
1084  b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1085 
1086  Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1087  expSumMinusTwo, inf, fmf);
1088  Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1089 
1090  Value resultReal = b.create<arith::SelectOp>(
1091  isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1092  Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1093 
1094  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1095  arith::FastMathFlags::ninf)) {
1096  Value absReal = b.create<math::AbsFOp>(real, fmf);
1097  Value zero = b.create<arith::ConstantOp>(
1098  elementType, b.getFloatAttr(elementType, 0.0));
1099  Value nan = cst(APFloat::getNaN(floatSemantics));
1100 
1101  Value absRealIsInf =
1102  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1103  Value imagIsZero =
1104  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1105  Value absRealIsNotInf = b.create<arith::XOrIOp>(
1106  absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1107 
1108  Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1109  imagNum, imagNum, fmf);
1110  Value resultRealIsNaN =
1111  b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1112  Value resultImagIsZero = b.create<arith::OrIOp>(
1113  imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1114 
1115  resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1116  resultImag =
1117  b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1118  }
1119 
1120  if constexpr (std::is_same_v<Op, complex::TanOp>) {
1121  // tan(x+yi) = -i*tanh(-y + xi)
1122  std::swap(resultReal, resultImag);
1123  resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
1124  }
1125 
1126  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1127  resultImag);
1128  return success();
1129  }
1130 };
1131 
1132 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
1134 
1135  LogicalResult
1136  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1137  ConversionPatternRewriter &rewriter) const override {
1138  auto loc = op.getLoc();
1139  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1140  auto elementType = cast<FloatType>(type.getElementType());
1141  Value real =
1142  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1143  Value imag =
1144  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1145  Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
1146 
1147  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
1148 
1149  return success();
1150  }
1151 };
1152 
1153 /// Converts lhs^y = (a+bi)^(c+di) to
1154 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1155 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
1156 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
1157  ComplexType type, Value lhs, Value c, Value d,
1158  arith::FastMathFlags fmf) {
1159  auto elementType = cast<FloatType>(type.getElementType());
1160 
1161  Value a = builder.create<complex::ReOp>(lhs);
1162  Value b = builder.create<complex::ImOp>(lhs);
1163 
1164  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
1165  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
1166 
1167  Value negD = builder.create<arith::NegFOp>(d, fmf);
1168  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
1169  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
1170  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
1171 
1172  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1173  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
1174  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
1175  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
1176  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1177  Value cosQ = builder.create<math::CosOp>(q, fmf);
1178  Value sinQ = builder.create<math::SinOp>(q, fmf);
1179 
1180  Value inf = builder.create<arith::ConstantOp>(
1181  elementType,
1182  builder.getFloatAttr(elementType,
1183  APFloat::getInf(elementType.getFloatSemantics())));
1184  Value zero = builder.create<arith::ConstantOp>(
1185  elementType, builder.getFloatAttr(elementType, 0.0));
1186  Value one = builder.create<arith::ConstantOp>(
1187  elementType, builder.getFloatAttr(elementType, 1.0));
1188  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
1189  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
1190  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
1191 
1192  // Case 0:
1193  // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
1194  // Branch Cuts for Complex Elementary Functions or Much Ado About
1195  // Nothing's Sign Bit, W. Kahan, Section 10.
1196  Value absEqZero =
1197  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
1198  Value dEqZero =
1199  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1200  Value cEqZero =
1201  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1202  Value bEqZero =
1203  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1204 
1205  Value zeroLeC =
1206  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1207  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
1208  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
1209  Value complexOneOrZero =
1210  builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
1211  Value coeffCosSin =
1212  builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1213  Value cutoff0 = builder.create<arith::SelectOp>(
1214  builder.create<arith::AndIOp>(
1215  builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1216  complexOneOrZero, coeffCosSin);
1217 
1218  // Case 1:
1219  // x^0 is defined to be 1 for any x, see
1220  // Branch Cuts for Complex Elementary Functions or Much Ado About
1221  // Nothing's Sign Bit, W. Kahan, Section 10.
1222  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
1223  Value cutoff1 =
1224  builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1225 
1226  // Case 2:
1227  // 1^(c + d*i) = 1 + 0*i
1228  Value lhsEqOne = builder.create<arith::AndIOp>(
1229  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1230  bEqZero);
1231  Value cutoff2 =
1232  builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1233 
1234  // Case 3:
1235  // inf^(c + 0*i) = inf + 0*i, c > 0
1236  Value lhsEqInf = builder.create<arith::AndIOp>(
1237  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1238  bEqZero);
1239  Value rhsGt0 = builder.create<arith::AndIOp>(
1240  dEqZero,
1241  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1242  Value cutoff3 = builder.create<arith::SelectOp>(
1243  builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1244 
1245  // Case 4:
1246  // inf^(c + 0*i) = 0 + 0*i, c < 0
1247  Value rhsLt0 = builder.create<arith::AndIOp>(
1248  dEqZero,
1249  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1250  Value cutoff4 = builder.create<arith::SelectOp>(
1251  builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1252 
1253  return cutoff4;
1254 }
1255 
1256 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1258 
1259  LogicalResult
1260  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1261  ConversionPatternRewriter &rewriter) const override {
1262  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1263  auto type = cast<ComplexType>(adaptor.getLhs().getType());
1264  auto elementType = cast<FloatType>(type.getElementType());
1265 
1266  Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1267  Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1268 
1269  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1270  c, d, op.getFastmath())});
1271  return success();
1272  }
1273 };
1274 
1275 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1277 
1278  LogicalResult
1279  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1280  ConversionPatternRewriter &rewriter) const override {
1281  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1282  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1283  auto elementType = cast<FloatType>(type.getElementType());
1284 
1285  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1286 
1287  auto cst = [&](APFloat v) {
1288  return b.create<arith::ConstantOp>(elementType,
1289  b.getFloatAttr(elementType, v));
1290  };
1291  const auto &floatSemantics = elementType.getFloatSemantics();
1292  Value zero = cst(APFloat::getZero(floatSemantics));
1293  Value inf = cst(APFloat::getInf(floatSemantics));
1294  Value negHalf = b.create<arith::ConstantOp>(
1295  elementType, b.getFloatAttr(elementType, -0.5));
1296  Value nan = cst(APFloat::getNaN(floatSemantics));
1297 
1298  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1299  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1300  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1301  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1302  Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1303  Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1304  Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1305 
1306  Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1307  Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1308 
1309  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1310  arith::FastMathFlags::ninf)) {
1311  Value negOne = b.create<arith::ConstantOp>(
1312  elementType, b.getFloatAttr(elementType, -1));
1313 
1314  Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1315  Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1316  Value negImagSignedZero =
1317  b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1318 
1319  Value absReal = b.create<math::AbsFOp>(real, fmf);
1320  Value absImag = b.create<math::AbsFOp>(imag, fmf);
1321 
1322  Value absImagIsInf =
1323  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1324  Value realIsNan =
1325  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1326  Value realIsInf =
1327  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1328  Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1329 
1330  Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1331 
1332  resultReal =
1333  b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1334  resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1335  resultImag);
1336  }
1337 
1338  Value isRealZero =
1339  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1340  Value isImagZero =
1341  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1342  Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1343 
1344  resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1345  resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1346 
1347  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1348  resultImag);
1349  return success();
1350  }
1351 };
1352 
1353 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1355 
1356  LogicalResult
1357  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1358  ConversionPatternRewriter &rewriter) const override {
1359  auto loc = op.getLoc();
1360  auto type = op.getType();
1361  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1362 
1363  Value real =
1364  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1365  Value imag =
1366  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1367 
1368  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1369 
1370  return success();
1371  }
1372 };
1373 
1374 } // namespace
1375 
1377  RewritePatternSet &patterns) {
  // Adds one conversion pattern per supported complex op to `patterns`. Every
  // listed pattern type is constructed from patterns.getContext(), so each must
  // provide a constructor taking an MLIRContext *.
1378  // clang-format off
1379  patterns.add<
1380  AbsOpConversion,
1381  AngleOpConversion,
1382  Atan2OpConversion,
1383  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1384  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1385  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1386  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1387  ConjOpConversion,
1388  CosOpConversion,
1389  DivOpConversion,
1390  ExpOpConversion,
1391  Expm1OpConversion,
1392  Log1pOpConversion,
1393  LogOpConversion,
1394  MulOpConversion,
1395  NegOpConversion,
1396  SignOpConversion,
1397  SinOpConversion,
1398  SqrtOpConversion,
1399  TanTanhOpConversion<complex::TanOp>,
1400  TanTanhOpConversion<complex::TanhOp>,
1401  PowOpConversion,
1402  RsqrtOpConversion
1403  >(patterns.getContext());
1404  // clang-format on
1405 }
1406 
1407 namespace {
1408 struct ConvertComplexToStandardPass
1409  : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1410  void runOnOperation() override;
1411 };
1412 
1413 void ConvertComplexToStandardPass::runOnOperation() {
1414  // Convert to the Standard dialect using the converter defined above.
1415  RewritePatternSet patterns(&getContext());
1417 
1418  ConversionTarget target(getContext());
1419  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1420  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1421  if (failed(
1422  applyPartialConversion(getOperation(), target, std::move(patterns))))
1423  signalPassFailure();
1424 }
1425 } // namespace
1426 
  // Returns a newly allocated ConvertComplexToStandardPass; ownership passes to
  // the caller through the returned std::unique_ptr.
1428  return std::make_unique<ConvertComplexToStandardPass>();
1429 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
IntegerType getI1Type()
Definition: Builders.cpp:97
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.