MLIR  19.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 #include <type_traits>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 
// Selects which function of the complex magnitude |z| computeAbs emits.
enum class AbsFn {
  abs,   // |z|
  sqrt,  // sqrt(|z|)
  rsqrt, // 1 / sqrt(|z|)
};
31 
32 // Returns the absolute value, its square root or its reciprocal square root.
33 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
34  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
35  Value one = b.create<arith::ConstantOp>(real.getType(),
36  b.getFloatAttr(real.getType(), 1.0));
37 
38  Value absReal = b.create<math::AbsFOp>(real, fmf);
39  Value absImag = b.create<math::AbsFOp>(imag, fmf);
40 
41  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
42  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
43  Value ratio = b.create<arith::DivFOp>(min, max, fmf);
44  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmf);
45  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmf);
46  Value result;
47 
48  if (fn == AbsFn::rsqrt) {
49  ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmf);
50  min = b.create<math::RsqrtOp>(min, fmf);
51  max = b.create<math::RsqrtOp>(max, fmf);
52  }
53 
54  if (fn == AbsFn::sqrt) {
55  Value quarter = b.create<arith::ConstantOp>(
56  real.getType(), b.getFloatAttr(real.getType(), 0.25));
57  // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
58  Value sqrt = b.create<math::SqrtOp>(max, fmf);
59  Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmf);
60  result = b.create<arith::MulFOp>(sqrt, p025, fmf);
61  } else {
62  Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmf);
63  result = b.create<arith::MulFOp>(max, sqrt, fmf);
64  }
65 
66  Value isNaN =
67  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result, result, fmf);
68  return b.create<arith::SelectOp>(isNaN, min, result);
69 }
70 
71 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
73 
75  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
76  ConversionPatternRewriter &rewriter) const override {
77  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
78 
79  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
80 
81  Value real = b.create<complex::ReOp>(adaptor.getComplex());
82  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
83  rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
84 
85  return success();
86  }
87 };
88 
89 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
90 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
92 
94  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
95  ConversionPatternRewriter &rewriter) const override {
96  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
97 
98  auto type = cast<ComplexType>(op.getType());
99  Type elementType = type.getElementType();
100  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
101 
102  Value lhs = adaptor.getLhs();
103  Value rhs = adaptor.getRhs();
104 
105  Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
106  Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
107  Value rhsSquaredPlusLhsSquared =
108  b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
109  Value sqrtOfRhsSquaredPlusLhsSquared =
110  b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
111 
112  Value zero =
113  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
114  Value one = b.create<arith::ConstantOp>(elementType,
115  b.getFloatAttr(elementType, 1));
116  Value i = b.create<complex::CreateOp>(type, zero, one);
117  Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
118  Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
119 
120  Value divResult = b.create<complex::DivOp>(
121  rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
122  Value logResult = b.create<complex::LogOp>(divResult, fmf);
123 
124  Value negativeOne = b.create<arith::ConstantOp>(
125  elementType, b.getFloatAttr(elementType, -1));
126  Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
127 
128  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
129  return success();
130  }
131 };
132 
133 template <typename ComparisonOp, arith::CmpFPredicate p>
134 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
136  using ResultCombiner =
137  std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
138  arith::AndIOp, arith::OrIOp>;
139 
141  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
142  ConversionPatternRewriter &rewriter) const override {
143  auto loc = op.getLoc();
144  auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
145 
146  Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
147  Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
148  Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
149  Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
150  Value realComparison =
151  rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
152  Value imagComparison =
153  rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
154 
155  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
156  imagComparison);
157  return success();
158  }
159 };
160 
161 // Default conversion which applies the BinaryStandardOp separately on the real
162 // and imaginary parts. Can for example be used for complex::AddOp and
163 // complex::SubOp.
164 template <typename BinaryComplexOp, typename BinaryStandardOp>
165 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
167 
169  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
170  ConversionPatternRewriter &rewriter) const override {
171  auto type = cast<ComplexType>(adaptor.getLhs().getType());
172  auto elementType = cast<FloatType>(type.getElementType());
173  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
174  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
175 
176  Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
177  Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
178  Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
179  fmf.getValue());
180  Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
181  Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
182  Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
183  fmf.getValue());
184  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
185  resultImag);
186  return success();
187  }
188 };
189 
190 template <typename TrigonometricOp>
191 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
192  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
193 
195 
197  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
198  ConversionPatternRewriter &rewriter) const override {
199  auto loc = op.getLoc();
200  auto type = cast<ComplexType>(adaptor.getComplex().getType());
201  auto elementType = cast<FloatType>(type.getElementType());
202  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
203 
204  Value real =
205  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
206  Value imag =
207  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
208 
209  // Trigonometric ops use a set of common building blocks to convert to real
210  // ops. Here we create these building blocks and call into an op-specific
211  // implementation in the subclass to combine them.
212  Value half = rewriter.create<arith::ConstantOp>(
213  loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
214  Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
215  Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
216  Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
217  Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
218  Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
219 
220  auto resultPair =
221  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
222 
223  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
224  resultPair.second);
225  return success();
226  }
227 
228  virtual std::pair<Value, Value>
229  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
230  Value cos, ConversionPatternRewriter &rewriter,
231  arith::FastMathFlagsAttr fmf) const = 0;
232 };
233 
234 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
235  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
236 
237  std::pair<Value, Value> combine(Location loc, Value scaledExp,
238  Value reciprocalExp, Value sin, Value cos,
239  ConversionPatternRewriter &rewriter,
240  arith::FastMathFlagsAttr fmf) const override {
241  // Complex cosine is defined as;
242  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
243  // Plugging in:
244  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
245  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
246  // and defining t := exp(y)
247  // We get:
248  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
249  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
250  Value sum =
251  rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
252  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
253  Value diff =
254  rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
255  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
256  return {resultReal, resultImag};
257  }
258 };
259 
260 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
262 
264  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
265  ConversionPatternRewriter &rewriter) const override {
266  auto loc = op.getLoc();
267  auto type = cast<ComplexType>(adaptor.getLhs().getType());
268  auto elementType = cast<FloatType>(type.getElementType());
269  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
270 
271  Value lhsReal =
272  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
273  Value lhsImag =
274  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
275  Value rhsReal =
276  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
277  Value rhsImag =
278  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
279 
280  // Smith's algorithm to divide complex numbers. It is just a bit smarter
281  // way to compute the following formula:
282  // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
283  // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
284  // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
285  // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
286  // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
287  //
288  // Depending on whether |rhsReal| < |rhsImag| we compute either
289  // rhsRealImagRatio = rhsReal / rhsImag
290  // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
291  // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
292  // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
293  //
294  // or
295  //
296  // rhsImagRealRatio = rhsImag / rhsReal
297  // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
298  // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
299  // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
300  //
301  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
302  Value rhsRealImagRatio =
303  rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag, fmf);
304  Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
305  loc, rhsImag,
306  rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal, fmf),
307  fmf);
308  Value realNumerator1 = rewriter.create<arith::AddFOp>(
309  loc,
310  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio, fmf),
311  lhsImag, fmf);
312  Value resultReal1 = rewriter.create<arith::DivFOp>(loc, realNumerator1,
313  rhsRealImagDenom, fmf);
314  Value imagNumerator1 = rewriter.create<arith::SubFOp>(
315  loc,
316  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio, fmf),
317  lhsReal, fmf);
318  Value resultImag1 = rewriter.create<arith::DivFOp>(loc, imagNumerator1,
319  rhsRealImagDenom, fmf);
320 
321  Value rhsImagRealRatio =
322  rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal, fmf);
323  Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
324  loc, rhsReal,
325  rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag, fmf),
326  fmf);
327  Value realNumerator2 = rewriter.create<arith::AddFOp>(
328  loc, lhsReal,
329  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio, fmf),
330  fmf);
331  Value resultReal2 = rewriter.create<arith::DivFOp>(loc, realNumerator2,
332  rhsImagRealDenom, fmf);
333  Value imagNumerator2 = rewriter.create<arith::SubFOp>(
334  loc, lhsImag,
335  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio, fmf),
336  fmf);
337  Value resultImag2 = rewriter.create<arith::DivFOp>(loc, imagNumerator2,
338  rhsImagRealDenom, fmf);
339 
340  // Consider corner cases.
341  // Case 1. Zero denominator, numerator contains at most one NaN value.
342  Value zero = rewriter.create<arith::ConstantOp>(
343  loc, elementType, rewriter.getZeroAttr(elementType));
344  Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal, fmf);
345  Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
346  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
347  Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag, fmf);
348  Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
349  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
350  Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
351  loc, arith::CmpFPredicate::ORD, lhsReal, zero);
352  Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
353  loc, arith::CmpFPredicate::ORD, lhsImag, zero);
354  Value lhsContainsNotNaNValue =
355  rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
356  Value resultIsInfinity = rewriter.create<arith::AndIOp>(
357  loc, lhsContainsNotNaNValue,
358  rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
359  Value inf = rewriter.create<arith::ConstantOp>(
360  loc, elementType,
361  rewriter.getFloatAttr(
362  elementType, APFloat::getInf(elementType.getFloatSemantics())));
363  Value infWithSignOfRhsReal =
364  rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
365  Value infinityResultReal =
366  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal, fmf);
367  Value infinityResultImag =
368  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag, fmf);
369 
370  // Case 2. Infinite numerator, finite denominator.
371  Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
372  loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
373  Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
374  loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
375  Value rhsFinite =
376  rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
377  Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal, fmf);
378  Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
379  loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
380  Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag, fmf);
381  Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
382  loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
383  Value lhsInfinite =
384  rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
385  Value infNumFiniteDenom =
386  rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
387  Value one = rewriter.create<arith::ConstantOp>(
388  loc, elementType, rewriter.getFloatAttr(elementType, 1));
389  Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
390  loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
391  lhsReal);
392  Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
393  loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
394  lhsImag);
395  Value lhsRealIsInfWithSignTimesRhsReal =
396  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal, fmf);
397  Value lhsImagIsInfWithSignTimesRhsImag =
398  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag, fmf);
399  Value resultReal3 = rewriter.create<arith::MulFOp>(
400  loc, inf,
401  rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
402  lhsImagIsInfWithSignTimesRhsImag, fmf),
403  fmf);
404  Value lhsRealIsInfWithSignTimesRhsImag =
405  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag, fmf);
406  Value lhsImagIsInfWithSignTimesRhsReal =
407  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal, fmf);
408  Value resultImag3 = rewriter.create<arith::MulFOp>(
409  loc, inf,
410  rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
411  lhsRealIsInfWithSignTimesRhsImag, fmf),
412  fmf);
413 
414  // Case 3: Finite numerator, infinite denominator.
415  Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
416  loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
417  Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
418  loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
419  Value lhsFinite =
420  rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
421  Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
422  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
423  Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
424  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
425  Value rhsInfinite =
426  rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
427  Value finiteNumInfiniteDenom =
428  rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
429  Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
430  loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
431  rhsReal);
432  Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
433  loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
434  rhsImag);
435  Value rhsRealIsInfWithSignTimesLhsReal =
436  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign, fmf);
437  Value rhsImagIsInfWithSignTimesLhsImag =
438  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign, fmf);
439  Value resultReal4 = rewriter.create<arith::MulFOp>(
440  loc, zero,
441  rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
442  rhsImagIsInfWithSignTimesLhsImag, fmf),
443  fmf);
444  Value rhsRealIsInfWithSignTimesLhsImag =
445  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign, fmf);
446  Value rhsImagIsInfWithSignTimesLhsReal =
447  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign, fmf);
448  Value resultImag4 = rewriter.create<arith::MulFOp>(
449  loc, zero,
450  rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
451  rhsImagIsInfWithSignTimesLhsReal, fmf),
452  fmf);
453 
454  Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
455  loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
456  Value resultReal = rewriter.create<arith::SelectOp>(
457  loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
458  Value resultImag = rewriter.create<arith::SelectOp>(
459  loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
460  Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
461  loc, finiteNumInfiniteDenom, resultReal4, resultReal);
462  Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
463  loc, finiteNumInfiniteDenom, resultImag4, resultImag);
464  Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
465  loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
466  Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
467  loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
468  Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
469  loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
470  Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
471  loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
472 
473  Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
474  loc, arith::CmpFPredicate::UNO, resultReal, zero);
475  Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
476  loc, arith::CmpFPredicate::UNO, resultImag, zero);
477  Value resultIsNaN =
478  rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
479  Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
480  loc, resultIsNaN, resultRealSpecialCase1, resultReal);
481  Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
482  loc, resultIsNaN, resultImagSpecialCase1, resultImag);
483 
484  rewriter.replaceOpWithNewOp<complex::CreateOp>(
485  op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
486  return success();
487  }
488 };
489 
490 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
492 
494  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
495  ConversionPatternRewriter &rewriter) const override {
496  auto loc = op.getLoc();
497  auto type = cast<ComplexType>(adaptor.getComplex().getType());
498  auto elementType = cast<FloatType>(type.getElementType());
499  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
500 
501  Value real =
502  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
503  Value imag =
504  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
505  Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
506  Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
507  Value resultReal =
508  rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
509  Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
510  Value resultImag =
511  rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
512 
513  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
514  resultImag);
515  return success();
516  }
517 };
518 
519 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
521 
523  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
524  ConversionPatternRewriter &rewriter) const override {
525  auto type = cast<ComplexType>(adaptor.getComplex().getType());
526  auto elementType = cast<FloatType>(type.getElementType());
527  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
528 
529  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
530  Value exp = b.create<complex::ExpOp>(adaptor.getComplex(), fmf.getValue());
531 
532  Value real = b.create<complex::ReOp>(elementType, exp);
533  Value one = b.create<arith::ConstantOp>(elementType,
534  b.getFloatAttr(elementType, 1));
535  Value realMinusOne = b.create<arith::SubFOp>(real, one, fmf.getValue());
536  Value imag = b.create<complex::ImOp>(elementType, exp);
537 
538  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
539  imag);
540  return success();
541  }
542 };
543 
544 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
546 
548  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
549  ConversionPatternRewriter &rewriter) const override {
550  auto type = cast<ComplexType>(adaptor.getComplex().getType());
551  auto elementType = cast<FloatType>(type.getElementType());
552  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
553  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
554 
555  Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
556  fmf.getValue());
557  Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
558  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
559  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
560  Value resultImag =
561  b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
562  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
563  resultImag);
564  return success();
565  }
566 };
567 
568 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
570 
572  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
573  ConversionPatternRewriter &rewriter) const override {
574  auto type = cast<ComplexType>(adaptor.getComplex().getType());
575  auto elementType = cast<FloatType>(type.getElementType());
576  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
577  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
578 
579  Value real = b.create<complex::ReOp>(adaptor.getComplex());
580  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
581 
582  Value half = b.create<arith::ConstantOp>(elementType,
583  b.getFloatAttr(elementType, 0.5));
584  Value one = b.create<arith::ConstantOp>(elementType,
585  b.getFloatAttr(elementType, 1));
586  Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
587  Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
588  Value absImag = b.create<math::AbsFOp>(imag, fmf);
589 
590  Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
591  Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
592 
593  Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
594  realPlusOne, absImag, fmf);
595  Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
596  Value maxAbsOfRealPlusOneAndImagMinusOne =
597  b.create<arith::SelectOp>(useReal, real, maxMinusOne);
598  Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmf);
599  Value logOfMaxAbsOfRealPlusOneAndImag =
600  b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
601  Value logOfSqrtPart = b.create<math::Log1pOp>(
602  b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmf), fmf);
603  Value r = b.create<arith::AddFOp>(
604  b.create<arith::MulFOp>(half, logOfSqrtPart, fmf),
605  logOfMaxAbsOfRealPlusOneAndImag, fmf);
606  Value resultReal = b.create<arith::SelectOp>(
607  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmf), minAbs,
608  r);
609  Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
610  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
611  resultImag);
612  return success();
613  }
614 };
615 
616 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
618 
620  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
621  ConversionPatternRewriter &rewriter) const override {
622  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
623  auto type = cast<ComplexType>(adaptor.getLhs().getType());
624  auto elementType = cast<FloatType>(type.getElementType());
625  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
626  auto fmfValue = fmf.getValue();
627 
628  Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
629  Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal, fmfValue);
630  Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
631  Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag, fmfValue);
632  Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
633  Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal, fmfValue);
634  Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
635  Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag, fmfValue);
636 
637  Value lhsRealTimesRhsReal =
638  b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
639  Value lhsRealTimesRhsRealAbs =
640  b.create<math::AbsFOp>(lhsRealTimesRhsReal, fmfValue);
641  Value lhsImagTimesRhsImag =
642  b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
643  Value lhsImagTimesRhsImagAbs =
644  b.create<math::AbsFOp>(lhsImagTimesRhsImag, fmfValue);
645  Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
646  lhsImagTimesRhsImag, fmfValue);
647 
648  Value lhsImagTimesRhsReal =
649  b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
650  Value lhsImagTimesRhsRealAbs =
651  b.create<math::AbsFOp>(lhsImagTimesRhsReal, fmfValue);
652  Value lhsRealTimesRhsImag =
653  b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
654  Value lhsRealTimesRhsImagAbs =
655  b.create<math::AbsFOp>(lhsRealTimesRhsImag, fmfValue);
656  Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
657  lhsRealTimesRhsImag, fmfValue);
658 
659  // Handle cases where the "naive" calculation results in NaN values.
660  Value realIsNan =
661  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
662  Value imagIsNan =
663  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
664  Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
665 
666  Value inf = b.create<arith::ConstantOp>(
667  elementType,
668  b.getFloatAttr(elementType,
669  APFloat::getInf(elementType.getFloatSemantics())));
670 
671  // Case 1. `lhsReal` or `lhsImag` are infinite.
672  Value lhsRealIsInf =
673  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
674  Value lhsImagIsInf =
675  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
676  Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
677  Value rhsRealIsNan =
678  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
679  Value rhsImagIsNan =
680  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
681  Value zero =
682  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
683  Value one = b.create<arith::ConstantOp>(elementType,
684  b.getFloatAttr(elementType, 1));
685  Value lhsRealIsInfFloat =
686  b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
687  lhsReal = b.create<arith::SelectOp>(
688  lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
689  lhsReal);
690  Value lhsImagIsInfFloat =
691  b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
692  lhsImag = b.create<arith::SelectOp>(
693  lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
694  lhsImag);
695  Value lhsIsInfAndRhsRealIsNan =
696  b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
697  rhsReal = b.create<arith::SelectOp>(
698  lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
699  rhsReal);
700  Value lhsIsInfAndRhsImagIsNan =
701  b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
702  rhsImag = b.create<arith::SelectOp>(
703  lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
704  rhsImag);
705 
706  // Case 2. `rhsReal` or `rhsImag` are infinite.
707  Value rhsRealIsInf =
708  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
709  Value rhsImagIsInf =
710  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
711  Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
712  Value lhsRealIsNan =
713  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
714  Value lhsImagIsNan =
715  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
716  Value rhsRealIsInfFloat =
717  b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
718  rhsReal = b.create<arith::SelectOp>(
719  rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
720  rhsReal);
721  Value rhsImagIsInfFloat =
722  b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
723  rhsImag = b.create<arith::SelectOp>(
724  rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
725  rhsImag);
726  Value rhsIsInfAndLhsRealIsNan =
727  b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
728  lhsReal = b.create<arith::SelectOp>(
729  rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
730  lhsReal);
731  Value rhsIsInfAndLhsImagIsNan =
732  b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
733  lhsImag = b.create<arith::SelectOp>(
734  rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
735  lhsImag);
736  Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
737 
738  // Case 3. One of the pairwise products of left hand side with right hand
739  // side is infinite.
740  Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
741  arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
742  Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
743  arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
744  Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
745  lhsImagTimesRhsImagIsInf);
746  Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
747  arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
748  isSpecialCase =
749  b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
750  Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
751  arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
752  isSpecialCase =
753  b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
754  Type i1Type = b.getI1Type();
755  Value notRecalc = b.create<arith::XOrIOp>(
756  recalc,
757  b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
758  isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
759  Value isSpecialCaseAndLhsRealIsNan =
760  b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
761  lhsReal = b.create<arith::SelectOp>(
762  isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
763  lhsReal);
764  Value isSpecialCaseAndLhsImagIsNan =
765  b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
766  lhsImag = b.create<arith::SelectOp>(
767  isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
768  lhsImag);
769  Value isSpecialCaseAndRhsRealIsNan =
770  b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
771  rhsReal = b.create<arith::SelectOp>(
772  isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
773  rhsReal);
774  Value isSpecialCaseAndRhsImagIsNan =
775  b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
776  rhsImag = b.create<arith::SelectOp>(
777  isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
778  rhsImag);
779  recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
780  recalc = b.create<arith::AndIOp>(isNan, recalc);
781 
782  // Recalculate real part.
783  lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
784  lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
785  Value newReal = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
786  lhsImagTimesRhsImag, fmfValue);
787  real = b.create<arith::SelectOp>(
788  recalc, b.create<arith::MulFOp>(inf, newReal, fmfValue), real);
789 
790  // Recalculate imag part.
791  lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
792  lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
793  Value newImag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
794  lhsRealTimesRhsImag, fmfValue);
795  imag = b.create<arith::SelectOp>(
796  recalc, b.create<arith::MulFOp>(inf, newImag, fmfValue), imag);
797 
798  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
799  return success();
800  }
801 };
802 
803 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
805 
807  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
808  ConversionPatternRewriter &rewriter) const override {
809  auto loc = op.getLoc();
810  auto type = cast<ComplexType>(adaptor.getComplex().getType());
811  auto elementType = cast<FloatType>(type.getElementType());
812 
813  Value real =
814  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
815  Value imag =
816  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
817  Value negReal = rewriter.create<arith::NegFOp>(loc, real);
818  Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
819  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
820  return success();
821  }
822 };
823 
824 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
825  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
826 
827  std::pair<Value, Value> combine(Location loc, Value scaledExp,
828  Value reciprocalExp, Value sin, Value cos,
829  ConversionPatternRewriter &rewriter,
830  arith::FastMathFlagsAttr fmf) const override {
831  // Complex sine is defined as;
832  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
833  // Plugging in:
834  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
835  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
836  // and defining t := exp(y)
837  // We get:
838  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
839  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
840  Value sum =
841  rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
842  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
843  Value diff =
844  rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
845  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
846  return {resultReal, resultImag};
847  }
848 };
849 
850 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
851 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
853 
855  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
856  ConversionPatternRewriter &rewriter) const override {
857  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
858 
859  auto type = cast<ComplexType>(op.getType());
860  auto elementType = type.getElementType().cast<FloatType>();
861  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
862 
863  auto cst = [&](APFloat v) {
864  return b.create<arith::ConstantOp>(elementType,
865  b.getFloatAttr(elementType, v));
866  };
867  const auto &floatSemantics = elementType.getFloatSemantics();
868  Value zero = cst(APFloat::getZero(floatSemantics));
869  Value half = b.create<arith::ConstantOp>(elementType,
870  b.getFloatAttr(elementType, 0.5));
871 
872  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
873  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
874  Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
875  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
876  Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
877  Value cos = b.create<math::CosOp>(sqrtArg, fmf);
878  Value sin = b.create<math::SinOp>(sqrtArg, fmf);
879  // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
880  // 0 * inf.
881  Value sinIsZero =
882  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
883 
884  Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
885  Value resultImag = b.create<arith::SelectOp>(
886  sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
887  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
888  arith::FastMathFlags::ninf)) {
889  Value inf = cst(APFloat::getInf(floatSemantics));
890  Value negInf = cst(APFloat::getInf(floatSemantics, true));
891  Value nan = cst(APFloat::getNaN(floatSemantics));
892  Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
893 
894  Value absImagIsInf =
895  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
896  Value absImagIsNotInf =
897  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
898  Value realIsInf =
899  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
900  Value realIsNegInf =
901  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
902 
903  resultReal = b.create<arith::SelectOp>(
904  b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
905  resultReal);
906  resultReal = b.create<arith::SelectOp>(
907  b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
908 
909  Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
910  resultImag = b.create<arith::SelectOp>(
911  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
912  nan, resultImag);
913  resultImag = b.create<arith::SelectOp>(
914  b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
915  resultImag);
916  }
917 
918  Value resultIsZero =
919  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
920  resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
921  resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
922 
923  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
924  resultImag);
925  return success();
926  }
927 };
928 
929 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
931 
933  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
934  ConversionPatternRewriter &rewriter) const override {
935  auto type = cast<ComplexType>(adaptor.getComplex().getType());
936  auto elementType = cast<FloatType>(type.getElementType());
937  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
938  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
939 
940  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
941  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
942  Value zero =
943  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
944  Value realIsZero =
945  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
946  Value imagIsZero =
947  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
948  Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
949  auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
950  Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
951  Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
952  Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
953  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
954  adaptor.getComplex(), sign);
955  return success();
956  }
957 };
958 
959 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
961 
963  matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
964  ConversionPatternRewriter &rewriter) const override {
965  auto loc = op.getLoc();
966  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
967 
968  Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex(), fmf);
969  Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex(), fmf);
970  rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos, fmf);
971  return success();
972  }
973 };
974 
975 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
977 
979  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
980  ConversionPatternRewriter &rewriter) const override {
981  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
982  auto loc = op.getLoc();
983  auto type = cast<ComplexType>(adaptor.getComplex().getType());
984  auto elementType = cast<FloatType>(type.getElementType());
985  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
986  const auto &floatSemantics = elementType.getFloatSemantics();
987 
988  Value real =
989  b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
990  Value imag =
991  b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
992 
993  auto cst = [&](APFloat v) {
994  return b.create<arith::ConstantOp>(elementType,
995  b.getFloatAttr(elementType, v));
996  };
997  Value inf = cst(APFloat::getInf(floatSemantics));
998  Value negOne = b.create<arith::ConstantOp>(
999  elementType, b.getFloatAttr(elementType, -1.0));
1000  Value four = b.create<arith::ConstantOp>(elementType,
1001  b.getFloatAttr(elementType, 4.0));
1002  Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
1003  Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
1004 
1005  Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
1006  Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
1007  Value realNum =
1008  b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1009 
1010  Value cosImag = b.create<math::CosOp>(imag, fmf);
1011  Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
1012  Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
1013  Value sinImag = b.create<math::SinOp>(imag, fmf);
1014 
1015  Value imagNum = b.create<arith::MulFOp>(
1016  four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
1017 
1018  Value expSumMinusTwo =
1019  b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
1020  Value denom =
1021  b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
1022 
1023  Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
1024  expSumMinusTwo, inf, fmf);
1025  Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
1026 
1027  Value resultReal = b.create<arith::SelectOp>(
1028  isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
1029  Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
1030 
1031  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1032  arith::FastMathFlags::ninf)) {
1033  Value absReal = b.create<math::AbsFOp>(real, fmf);
1034  Value zero = b.create<arith::ConstantOp>(
1035  elementType, b.getFloatAttr(elementType, 0.0));
1036  Value nan = cst(APFloat::getNaN(floatSemantics));
1037 
1038  Value absRealIsInf =
1039  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1040  Value imagIsZero =
1041  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1042  Value absRealIsNotInf = b.create<arith::XOrIOp>(
1043  absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
1044 
1045  Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
1046  imagNum, imagNum, fmf);
1047  Value resultRealIsNaN =
1048  b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
1049  Value resultImagIsZero = b.create<arith::OrIOp>(
1050  imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
1051 
1052  resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
1053  resultImag =
1054  b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
1055  }
1056 
1057  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1058  resultImag);
1059  return success();
1060  }
1061 };
1062 
1063 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
1065 
1067  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
1068  ConversionPatternRewriter &rewriter) const override {
1069  auto loc = op.getLoc();
1070  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1071  auto elementType = cast<FloatType>(type.getElementType());
1072  Value real =
1073  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
1074  Value imag =
1075  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
1076  Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
1077 
1078  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
1079 
1080  return success();
1081  }
1082 };
1083 
1084 /// Converts lhs^y = (a+bi)^(c+di) to
1085 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
1086 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
/// In the code below the magnitude sqrt(a*a+b*b) is taken from `complex.abs`,
/// so (a*a+b*b)^(0.5c) is emitted as abs^c and 0.5d*ln(a*a+b*b) as d*ln(abs).
/// A chain of selects (cutoff0 .. cutoff4) then overrides the general formula
/// for special operands; a later case takes precedence over an earlier one.
/// `c` and `d` are the real and imaginary parts of the exponent. Returns the
/// complex-typed result value.
1087 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
1088  ComplexType type, Value lhs, Value c, Value d,
1089  arith::FastMathFlags fmf) {
1090  auto elementType = cast<FloatType>(type.getElementType());
1091 
1092  Value a = builder.create<complex::ReOp>(lhs);
1093  Value b = builder.create<complex::ImOp>(lhs);
1094 
1095  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
1096  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
1097 
1098  Value negD = builder.create<arith::NegFOp>(d, fmf);
1099  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
1100  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
1101  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
1102 
// coeff = abs^c * exp(-d * arg(lhs)) is the magnitude of the result.
1103  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
1104  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
1105  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
1106  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
// q = c * arg(lhs) + d * ln(abs) is the angle of the result.
1107  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
1108  Value cosQ = builder.create<math::CosOp>(q, fmf);
1109  Value sinQ = builder.create<math::SinOp>(q, fmf);
1110 
1111  Value inf = builder.create<arith::ConstantOp>(
1112  elementType,
1113  builder.getFloatAttr(elementType,
1114  APFloat::getInf(elementType.getFloatSemantics())));
1115  Value zero = builder.create<arith::ConstantOp>(
1116  elementType, builder.getFloatAttr(elementType, 0.0));
1117  Value one = builder.create<arith::ConstantOp>(
1118  elementType, builder.getFloatAttr(elementType, 1.0));
1119  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
1120  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
1121  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
1122 
1123  // Case 0:
1124  // 0^c is 0 when the exponent is real (d == 0) and c > 0. 0^0 is defined
1125  // to be 1.0, see Branch Cuts for Complex Elementary Functions or Much Ado
1126  // About Nothing's Sign Bit, W. Kahan, Section 10.
1127  Value absEqZero =
1128  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
1129  Value dEqZero =
1130  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
1131  Value cEqZero =
1132  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
1133  Value bEqZero =
1134  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
1135 
1136  Value zeroLeC =
1137  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
1138  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
1139  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
// 1 when c == 0, otherwise 0: the value used when the base is zero.
1140  Value complexOneOrZero =
1141  builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
// General-case result: coeff * (cos(q) + i*sin(q)).
1142  Value coeffCosSin =
1143  builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
1144  Value cutoff0 = builder.create<arith::SelectOp>(
1145  builder.create<arith::AndIOp>(
1146  builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
1147  complexOneOrZero, coeffCosSin);
1148 
1149  // Case 1:
1150  // x^0 is defined to be 1 for any x, see
1151  // Branch Cuts for Complex Elementary Functions or Much Ado About
1152  // Nothing's Sign Bit, W. Kahan, Section 10.
1153  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
1154  Value cutoff1 =
1155  builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
1156 
1157  // Case 2:
1158  // 1^(c + d*i) = 1 + 0*i
1159  Value lhsEqOne = builder.create<arith::AndIOp>(
1160  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
1161  bEqZero);
1162  Value cutoff2 =
1163  builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
1164 
1165  // Case 3:
1166  // inf^(c + 0*i) = inf + 0*i, c > 0
1167  Value lhsEqInf = builder.create<arith::AndIOp>(
1168  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
1169  bEqZero);
1170  Value rhsGt0 = builder.create<arith::AndIOp>(
1171  dEqZero,
1172  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
1173  Value cutoff3 = builder.create<arith::SelectOp>(
1174  builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
1175 
1176  // Case 4:
1177  // inf^(c + 0*i) = 0 + 0*i, c < 0
1178  Value rhsLt0 = builder.create<arith::AndIOp>(
1179  dEqZero,
1180  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
1181  Value cutoff4 = builder.create<arith::SelectOp>(
1182  builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
1183 
1184  return cutoff4;
1185 }
1186 
1187 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
1189 
1191  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1192  ConversionPatternRewriter &rewriter) const override {
1193  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1194  auto type = cast<ComplexType>(adaptor.getLhs().getType());
1195  auto elementType = cast<FloatType>(type.getElementType());
1196 
1197  Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1198  Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1199 
1200  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1201  c, d, op.getFastmath())});
1202  return success();
1203  }
1204 };
1205 
1206 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1208 
1210  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1211  ConversionPatternRewriter &rewriter) const override {
1212  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1213  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1214  auto elementType = cast<FloatType>(type.getElementType());
1215 
1216  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1217 
1218  auto cst = [&](APFloat v) {
1219  return b.create<arith::ConstantOp>(elementType,
1220  b.getFloatAttr(elementType, v));
1221  };
1222  const auto &floatSemantics = elementType.getFloatSemantics();
1223  Value zero = cst(APFloat::getZero(floatSemantics));
1224  Value inf = cst(APFloat::getInf(floatSemantics));
1225  Value negHalf = b.create<arith::ConstantOp>(
1226  elementType, b.getFloatAttr(elementType, -0.5));
1227  Value nan = cst(APFloat::getNaN(floatSemantics));
1228 
1229  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
1230  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
1231  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1232  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
1233  Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
1234  Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
1235  Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
1236 
1237  Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
1238  Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
1239 
1240  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1241  arith::FastMathFlags::ninf)) {
1242  Value negOne = b.create<arith::ConstantOp>(
1243  elementType, b.getFloatAttr(elementType, -1));
1244 
1245  Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
1246  Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
1247  Value negImagSignedZero =
1248  b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
1249 
1250  Value absReal = b.create<math::AbsFOp>(real, fmf);
1251  Value absImag = b.create<math::AbsFOp>(imag, fmf);
1252 
1253  Value absImagIsInf =
1254  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
1255  Value realIsNan =
1256  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
1257  Value realIsInf =
1258  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
1259  Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
1260 
1261  Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
1262 
1263  resultReal =
1264  b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
1265  resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
1266  resultImag);
1267  }
1268 
1269  Value isRealZero =
1270  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1271  Value isImagZero =
1272  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1273  Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1274 
1275  resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1276  resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1277 
1278  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1279  resultImag);
1280  return success();
1281  }
1282 };
1283 
1284 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1286 
1288  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1289  ConversionPatternRewriter &rewriter) const override {
1290  auto loc = op.getLoc();
1291  auto type = op.getType();
1292 
1293  Value real =
1294  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1295  Value imag =
1296  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1297 
1298  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
1299 
1300  return success();
1301  }
1302 };
1303 
1304 } // namespace
1305 
1307  RewritePatternSet &patterns) {
1308  // clang-format off
1309  patterns.add<
1310  AbsOpConversion,
1311  AngleOpConversion,
1312  Atan2OpConversion,
1313  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1314  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1315  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1316  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1317  ConjOpConversion,
1318  CosOpConversion,
1319  DivOpConversion,
1320  ExpOpConversion,
1321  Expm1OpConversion,
1322  Log1pOpConversion,
1323  LogOpConversion,
1324  MulOpConversion,
1325  NegOpConversion,
1326  SignOpConversion,
1327  SinOpConversion,
1328  SqrtOpConversion,
1329  TanOpConversion,
1330  TanhOpConversion,
1331  PowOpConversion,
1332  RsqrtOpConversion
1333  >(patterns.getContext());
1334  // clang-format on
1335 }
1336 
1337 namespace {
/// Pass that converts Complex dialect operations; the work is done in
/// runOnOperation(), defined out of line below. The base class is the one
/// generated via GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD.
1338 struct ConvertComplexToStandardPass
1339  : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1340  void runOnOperation() override;
1341 };
1342 
1343 void ConvertComplexToStandardPass::runOnOperation() {
1344  // Convert to the Standard dialect using the converter defined above.
1345  RewritePatternSet patterns(&getContext());
1347 
1348  ConversionTarget target(getContext());
1349  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1350  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1351  if (failed(
1352  applyPartialConversion(getOperation(), target, std::move(patterns))))
1353  signalPassFailure();
1354 }
1355 } // namespace
1356 
1358  return std::make_unique<ConvertComplexToStandardPass>();
1359 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IntegerType getI1Type()
Definition: Builders.cpp:73
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Fraction abs(const Fraction &f)
Definition: Fraction.h:104
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26