MLIR  16.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 #include <type_traits>
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARD
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
31 
33  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
34  ConversionPatternRewriter &rewriter) const override {
35  auto loc = op.getLoc();
36  auto type = op.getType();
37 
38  Value real =
39  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
40  Value imag =
41  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
42  Value realSqr = rewriter.create<arith::MulFOp>(loc, real, real);
43  Value imagSqr = rewriter.create<arith::MulFOp>(loc, imag, imag);
44  Value sqNorm = rewriter.create<arith::AddFOp>(loc, realSqr, imagSqr);
45 
46  rewriter.replaceOpWithNewOp<math::SqrtOp>(op, sqNorm);
47  return success();
48  }
49 };
50 
51 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
52 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
54 
56  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
57  ConversionPatternRewriter &rewriter) const override {
58  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
59 
60  auto type = op.getType().cast<ComplexType>();
61  Type elementType = type.getElementType();
62 
63  Value lhs = adaptor.getLhs();
64  Value rhs = adaptor.getRhs();
65 
66  Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
67  Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
68  Value rhsSquaredPlusLhsSquared =
69  b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
70  Value sqrtOfRhsSquaredPlusLhsSquared =
71  b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
72 
73  Value zero =
74  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
75  Value one = b.create<arith::ConstantOp>(elementType,
76  b.getFloatAttr(elementType, 1));
77  Value i = b.create<complex::CreateOp>(type, zero, one);
78  Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
79  Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
80 
81  Value divResult =
82  b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
83  Value logResult = b.create<complex::LogOp>(divResult);
84 
85  Value negativeOne = b.create<arith::ConstantOp>(
86  elementType, b.getFloatAttr(elementType, -1));
87  Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
88 
89  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
90  return success();
91  }
92 };
93 
94 template <typename ComparisonOp, arith::CmpFPredicate p>
95 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
97  using ResultCombiner =
99  arith::AndIOp, arith::OrIOp>;
100 
102  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
103  ConversionPatternRewriter &rewriter) const override {
104  auto loc = op.getLoc();
105  auto type = adaptor.getLhs()
106  .getType()
107  .template cast<ComplexType>()
108  .getElementType();
109 
110  Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
111  Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
112  Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
113  Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
114  Value realComparison =
115  rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
116  Value imagComparison =
117  rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
118 
119  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
120  imagComparison);
121  return success();
122  }
123 };
124 
125 // Default conversion which applies the BinaryStandardOp separately on the real
126 // and imaginary parts. Can for example be used for complex::AddOp and
127 // complex::SubOp.
128 template <typename BinaryComplexOp, typename BinaryStandardOp>
129 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
131 
133  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
134  ConversionPatternRewriter &rewriter) const override {
135  auto type = adaptor.getLhs().getType().template cast<ComplexType>();
136  auto elementType = type.getElementType().template cast<FloatType>();
137  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
138 
139  Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
140  Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
141  Value resultReal =
142  b.create<BinaryStandardOp>(elementType, realLhs, realRhs);
143  Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
144  Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
145  Value resultImag =
146  b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs);
147  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
148  resultImag);
149  return success();
150  }
151 };
152 
153 template <typename TrigonometricOp>
154 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
155  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
156 
158 
160  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
161  ConversionPatternRewriter &rewriter) const override {
162  auto loc = op.getLoc();
163  auto type = adaptor.getComplex().getType().template cast<ComplexType>();
164  auto elementType = type.getElementType().template cast<FloatType>();
165 
166  Value real =
167  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
168  Value imag =
169  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
170 
171  // Trigonometric ops use a set of common building blocks to convert to real
172  // ops. Here we create these building blocks and call into an op-specific
173  // implementation in the subclass to combine them.
174  Value half = rewriter.create<arith::ConstantOp>(
175  loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
176  Value exp = rewriter.create<math::ExpOp>(loc, imag);
177  Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp);
178  Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp);
179  Value sin = rewriter.create<math::SinOp>(loc, real);
180  Value cos = rewriter.create<math::CosOp>(loc, real);
181 
182  auto resultPair =
183  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter);
184 
185  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
186  resultPair.second);
187  return success();
188  }
189 
190  virtual std::pair<Value, Value>
191  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
192  Value cos, ConversionPatternRewriter &rewriter) const = 0;
193 };
194 
195 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
196  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
197 
198  std::pair<Value, Value>
199  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
200  Value cos, ConversionPatternRewriter &rewriter) const override {
201  // Complex cosine is defined as;
202  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
203  // Plugging in:
204  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
205  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
206  // and defining t := exp(y)
207  // We get:
208  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
209  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
210  Value sum = rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp);
211  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos);
212  Value diff = rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp);
213  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin);
214  return {resultReal, resultImag};
215  }
216 };
217 
218 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
220 
222  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
223  ConversionPatternRewriter &rewriter) const override {
224  auto loc = op.getLoc();
225  auto type = adaptor.getLhs().getType().cast<ComplexType>();
226  auto elementType = type.getElementType().cast<FloatType>();
227 
228  Value lhsReal =
229  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
230  Value lhsImag =
231  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
232  Value rhsReal =
233  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
234  Value rhsImag =
235  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
236 
237  // Smith's algorithm to divide complex numbers. It is just a bit smarter
238  // way to compute the following formula:
239  // (lhsReal + lhsImag * i) / (rhsReal + rhsImag * i)
240  // = (lhsReal + lhsImag * i) (rhsReal - rhsImag * i) /
241  // ((rhsReal + rhsImag * i)(rhsReal - rhsImag * i))
242  // = ((lhsReal * rhsReal + lhsImag * rhsImag) +
243  // (lhsImag * rhsReal - lhsReal * rhsImag) * i) / ||rhs||^2
244  //
245  // Depending on whether |rhsReal| < |rhsImag| we compute either
246  // rhsRealImagRatio = rhsReal / rhsImag
247  // rhsRealImagDenom = rhsImag + rhsReal * rhsRealImagRatio
248  // resultReal = (lhsReal * rhsRealImagRatio + lhsImag) / rhsRealImagDenom
249  // resultImag = (lhsImag * rhsRealImagRatio - lhsReal) / rhsRealImagDenom
250  //
251  // or
252  //
253  // rhsImagRealRatio = rhsImag / rhsReal
254  // rhsImagRealDenom = rhsReal + rhsImag * rhsImagRealRatio
255  // resultReal = (lhsReal + lhsImag * rhsImagRealRatio) / rhsImagRealDenom
256  // resultImag = (lhsImag - lhsReal * rhsImagRealRatio) / rhsImagRealDenom
257  //
258  // See https://dl.acm.org/citation.cfm?id=368661 for more details.
259  Value rhsRealImagRatio =
260  rewriter.create<arith::DivFOp>(loc, rhsReal, rhsImag);
261  Value rhsRealImagDenom = rewriter.create<arith::AddFOp>(
262  loc, rhsImag,
263  rewriter.create<arith::MulFOp>(loc, rhsRealImagRatio, rhsReal));
264  Value realNumerator1 = rewriter.create<arith::AddFOp>(
265  loc, rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealImagRatio),
266  lhsImag);
267  Value resultReal1 =
268  rewriter.create<arith::DivFOp>(loc, realNumerator1, rhsRealImagDenom);
269  Value imagNumerator1 = rewriter.create<arith::SubFOp>(
270  loc, rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealImagRatio),
271  lhsReal);
272  Value resultImag1 =
273  rewriter.create<arith::DivFOp>(loc, imagNumerator1, rhsRealImagDenom);
274 
275  Value rhsImagRealRatio =
276  rewriter.create<arith::DivFOp>(loc, rhsImag, rhsReal);
277  Value rhsImagRealDenom = rewriter.create<arith::AddFOp>(
278  loc, rhsReal,
279  rewriter.create<arith::MulFOp>(loc, rhsImagRealRatio, rhsImag));
280  Value realNumerator2 = rewriter.create<arith::AddFOp>(
281  loc, lhsReal,
282  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagRealRatio));
283  Value resultReal2 =
284  rewriter.create<arith::DivFOp>(loc, realNumerator2, rhsImagRealDenom);
285  Value imagNumerator2 = rewriter.create<arith::SubFOp>(
286  loc, lhsImag,
287  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagRealRatio));
288  Value resultImag2 =
289  rewriter.create<arith::DivFOp>(loc, imagNumerator2, rhsImagRealDenom);
290 
291  // Consider corner cases.
292  // Case 1. Zero denominator, numerator contains at most one NaN value.
293  Value zero = rewriter.create<arith::ConstantOp>(
294  loc, elementType, rewriter.getZeroAttr(elementType));
295  Value rhsRealAbs = rewriter.create<math::AbsFOp>(loc, rhsReal);
296  Value rhsRealIsZero = rewriter.create<arith::CmpFOp>(
297  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, zero);
298  Value rhsImagAbs = rewriter.create<math::AbsFOp>(loc, rhsImag);
299  Value rhsImagIsZero = rewriter.create<arith::CmpFOp>(
300  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, zero);
301  Value lhsRealIsNotNaN = rewriter.create<arith::CmpFOp>(
302  loc, arith::CmpFPredicate::ORD, lhsReal, zero);
303  Value lhsImagIsNotNaN = rewriter.create<arith::CmpFOp>(
304  loc, arith::CmpFPredicate::ORD, lhsImag, zero);
305  Value lhsContainsNotNaNValue =
306  rewriter.create<arith::OrIOp>(loc, lhsRealIsNotNaN, lhsImagIsNotNaN);
307  Value resultIsInfinity = rewriter.create<arith::AndIOp>(
308  loc, lhsContainsNotNaNValue,
309  rewriter.create<arith::AndIOp>(loc, rhsRealIsZero, rhsImagIsZero));
310  Value inf = rewriter.create<arith::ConstantOp>(
311  loc, elementType,
312  rewriter.getFloatAttr(
313  elementType, APFloat::getInf(elementType.getFloatSemantics())));
314  Value infWithSignOfRhsReal =
315  rewriter.create<math::CopySignOp>(loc, inf, rhsReal);
316  Value infinityResultReal =
317  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsReal);
318  Value infinityResultImag =
319  rewriter.create<arith::MulFOp>(loc, infWithSignOfRhsReal, lhsImag);
320 
321  // Case 2. Infinite numerator, finite denominator.
322  Value rhsRealFinite = rewriter.create<arith::CmpFOp>(
323  loc, arith::CmpFPredicate::ONE, rhsRealAbs, inf);
324  Value rhsImagFinite = rewriter.create<arith::CmpFOp>(
325  loc, arith::CmpFPredicate::ONE, rhsImagAbs, inf);
326  Value rhsFinite =
327  rewriter.create<arith::AndIOp>(loc, rhsRealFinite, rhsImagFinite);
328  Value lhsRealAbs = rewriter.create<math::AbsFOp>(loc, lhsReal);
329  Value lhsRealInfinite = rewriter.create<arith::CmpFOp>(
330  loc, arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
331  Value lhsImagAbs = rewriter.create<math::AbsFOp>(loc, lhsImag);
332  Value lhsImagInfinite = rewriter.create<arith::CmpFOp>(
333  loc, arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
334  Value lhsInfinite =
335  rewriter.create<arith::OrIOp>(loc, lhsRealInfinite, lhsImagInfinite);
336  Value infNumFiniteDenom =
337  rewriter.create<arith::AndIOp>(loc, lhsInfinite, rhsFinite);
338  Value one = rewriter.create<arith::ConstantOp>(
339  loc, elementType, rewriter.getFloatAttr(elementType, 1));
340  Value lhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
341  loc, rewriter.create<arith::SelectOp>(loc, lhsRealInfinite, one, zero),
342  lhsReal);
343  Value lhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
344  loc, rewriter.create<arith::SelectOp>(loc, lhsImagInfinite, one, zero),
345  lhsImag);
346  Value lhsRealIsInfWithSignTimesRhsReal =
347  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsReal);
348  Value lhsImagIsInfWithSignTimesRhsImag =
349  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsImag);
350  Value resultReal3 = rewriter.create<arith::MulFOp>(
351  loc, inf,
352  rewriter.create<arith::AddFOp>(loc, lhsRealIsInfWithSignTimesRhsReal,
353  lhsImagIsInfWithSignTimesRhsImag));
354  Value lhsRealIsInfWithSignTimesRhsImag =
355  rewriter.create<arith::MulFOp>(loc, lhsRealIsInfWithSign, rhsImag);
356  Value lhsImagIsInfWithSignTimesRhsReal =
357  rewriter.create<arith::MulFOp>(loc, lhsImagIsInfWithSign, rhsReal);
358  Value resultImag3 = rewriter.create<arith::MulFOp>(
359  loc, inf,
360  rewriter.create<arith::SubFOp>(loc, lhsImagIsInfWithSignTimesRhsReal,
361  lhsRealIsInfWithSignTimesRhsImag));
362 
363  // Case 3: Finite numerator, infinite denominator.
364  Value lhsRealFinite = rewriter.create<arith::CmpFOp>(
365  loc, arith::CmpFPredicate::ONE, lhsRealAbs, inf);
366  Value lhsImagFinite = rewriter.create<arith::CmpFOp>(
367  loc, arith::CmpFPredicate::ONE, lhsImagAbs, inf);
368  Value lhsFinite =
369  rewriter.create<arith::AndIOp>(loc, lhsRealFinite, lhsImagFinite);
370  Value rhsRealInfinite = rewriter.create<arith::CmpFOp>(
371  loc, arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
372  Value rhsImagInfinite = rewriter.create<arith::CmpFOp>(
373  loc, arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
374  Value rhsInfinite =
375  rewriter.create<arith::OrIOp>(loc, rhsRealInfinite, rhsImagInfinite);
376  Value finiteNumInfiniteDenom =
377  rewriter.create<arith::AndIOp>(loc, lhsFinite, rhsInfinite);
378  Value rhsRealIsInfWithSign = rewriter.create<math::CopySignOp>(
379  loc, rewriter.create<arith::SelectOp>(loc, rhsRealInfinite, one, zero),
380  rhsReal);
381  Value rhsImagIsInfWithSign = rewriter.create<math::CopySignOp>(
382  loc, rewriter.create<arith::SelectOp>(loc, rhsImagInfinite, one, zero),
383  rhsImag);
384  Value rhsRealIsInfWithSignTimesLhsReal =
385  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsRealIsInfWithSign);
386  Value rhsImagIsInfWithSignTimesLhsImag =
387  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsImagIsInfWithSign);
388  Value resultReal4 = rewriter.create<arith::MulFOp>(
389  loc, zero,
390  rewriter.create<arith::AddFOp>(loc, rhsRealIsInfWithSignTimesLhsReal,
391  rhsImagIsInfWithSignTimesLhsImag));
392  Value rhsRealIsInfWithSignTimesLhsImag =
393  rewriter.create<arith::MulFOp>(loc, lhsImag, rhsRealIsInfWithSign);
394  Value rhsImagIsInfWithSignTimesLhsReal =
395  rewriter.create<arith::MulFOp>(loc, lhsReal, rhsImagIsInfWithSign);
396  Value resultImag4 = rewriter.create<arith::MulFOp>(
397  loc, zero,
398  rewriter.create<arith::SubFOp>(loc, rhsRealIsInfWithSignTimesLhsImag,
399  rhsImagIsInfWithSignTimesLhsReal));
400 
401  Value realAbsSmallerThanImagAbs = rewriter.create<arith::CmpFOp>(
402  loc, arith::CmpFPredicate::OLT, rhsRealAbs, rhsImagAbs);
403  Value resultReal = rewriter.create<arith::SelectOp>(
404  loc, realAbsSmallerThanImagAbs, resultReal1, resultReal2);
405  Value resultImag = rewriter.create<arith::SelectOp>(
406  loc, realAbsSmallerThanImagAbs, resultImag1, resultImag2);
407  Value resultRealSpecialCase3 = rewriter.create<arith::SelectOp>(
408  loc, finiteNumInfiniteDenom, resultReal4, resultReal);
409  Value resultImagSpecialCase3 = rewriter.create<arith::SelectOp>(
410  loc, finiteNumInfiniteDenom, resultImag4, resultImag);
411  Value resultRealSpecialCase2 = rewriter.create<arith::SelectOp>(
412  loc, infNumFiniteDenom, resultReal3, resultRealSpecialCase3);
413  Value resultImagSpecialCase2 = rewriter.create<arith::SelectOp>(
414  loc, infNumFiniteDenom, resultImag3, resultImagSpecialCase3);
415  Value resultRealSpecialCase1 = rewriter.create<arith::SelectOp>(
416  loc, resultIsInfinity, infinityResultReal, resultRealSpecialCase2);
417  Value resultImagSpecialCase1 = rewriter.create<arith::SelectOp>(
418  loc, resultIsInfinity, infinityResultImag, resultImagSpecialCase2);
419 
420  Value resultRealIsNaN = rewriter.create<arith::CmpFOp>(
421  loc, arith::CmpFPredicate::UNO, resultReal, zero);
422  Value resultImagIsNaN = rewriter.create<arith::CmpFOp>(
423  loc, arith::CmpFPredicate::UNO, resultImag, zero);
424  Value resultIsNaN =
425  rewriter.create<arith::AndIOp>(loc, resultRealIsNaN, resultImagIsNaN);
426  Value resultRealWithSpecialCases = rewriter.create<arith::SelectOp>(
427  loc, resultIsNaN, resultRealSpecialCase1, resultReal);
428  Value resultImagWithSpecialCases = rewriter.create<arith::SelectOp>(
429  loc, resultIsNaN, resultImagSpecialCase1, resultImag);
430 
431  rewriter.replaceOpWithNewOp<complex::CreateOp>(
432  op, type, resultRealWithSpecialCases, resultImagWithSpecialCases);
433  return success();
434  }
435 };
436 
437 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
439 
441  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
442  ConversionPatternRewriter &rewriter) const override {
443  auto loc = op.getLoc();
444  auto type = adaptor.getComplex().getType().cast<ComplexType>();
445  auto elementType = type.getElementType().cast<FloatType>();
446 
447  Value real =
448  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
449  Value imag =
450  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
451  Value expReal = rewriter.create<math::ExpOp>(loc, real);
452  Value cosImag = rewriter.create<math::CosOp>(loc, imag);
453  Value resultReal = rewriter.create<arith::MulFOp>(loc, expReal, cosImag);
454  Value sinImag = rewriter.create<math::SinOp>(loc, imag);
455  Value resultImag = rewriter.create<arith::MulFOp>(loc, expReal, sinImag);
456 
457  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
458  resultImag);
459  return success();
460  }
461 };
462 
463 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
465 
467  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
468  ConversionPatternRewriter &rewriter) const override {
469  auto type = adaptor.getComplex().getType().cast<ComplexType>();
470  auto elementType = type.getElementType().cast<FloatType>();
471 
472  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
473  Value exp = b.create<complex::ExpOp>(adaptor.getComplex());
474 
475  Value real = b.create<complex::ReOp>(elementType, exp);
476  Value one = b.create<arith::ConstantOp>(elementType,
477  b.getFloatAttr(elementType, 1));
478  Value realMinusOne = b.create<arith::SubFOp>(real, one);
479  Value imag = b.create<complex::ImOp>(elementType, exp);
480 
481  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realMinusOne,
482  imag);
483  return success();
484  }
485 };
486 
487 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
489 
491  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
492  ConversionPatternRewriter &rewriter) const override {
493  auto type = adaptor.getComplex().getType().cast<ComplexType>();
494  auto elementType = type.getElementType().cast<FloatType>();
495  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
496 
497  Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
498  Value resultReal = b.create<math::LogOp>(elementType, abs);
499  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
500  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
501  Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
502  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
503  resultImag);
504  return success();
505  }
506 };
507 
508 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
510 
512  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
513  ConversionPatternRewriter &rewriter) const override {
514  auto type = adaptor.getComplex().getType().cast<ComplexType>();
515  auto elementType = type.getElementType().cast<FloatType>();
516  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
517 
518  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
519  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
520 
521  Value half = b.create<arith::ConstantOp>(elementType,
522  b.getFloatAttr(elementType, 0.5));
523  Value one = b.create<arith::ConstantOp>(elementType,
524  b.getFloatAttr(elementType, 1));
525  Value two = b.create<arith::ConstantOp>(elementType,
526  b.getFloatAttr(elementType, 2));
527 
528  // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
529  // log((a+1)+bi) = .5*log(a*a + 2*a + 1 + b*b) + i*atan2(b, a+1)
530  // log((a+1)+bi) = .5*log1p(a*a + 2*a + b*b) + i*atan2(b, a+1)
531  Value sumSq = b.create<arith::MulFOp>(real, real);
532  sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(real, two));
533  sumSq = b.create<arith::AddFOp>(sumSq, b.create<arith::MulFOp>(imag, imag));
534  Value logSumSq = b.create<math::Log1pOp>(elementType, sumSq);
535  Value resultReal = b.create<arith::MulFOp>(logSumSq, half);
536 
537  Value realPlusOne = b.create<arith::AddFOp>(real, one);
538 
539  Value resultImag = b.create<math::Atan2Op>(elementType, imag, realPlusOne);
540  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
541  resultImag);
542  return success();
543  }
544 };
545 
546 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
548 
550  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
551  ConversionPatternRewriter &rewriter) const override {
552  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
553  auto type = adaptor.getLhs().getType().cast<ComplexType>();
554  auto elementType = type.getElementType().cast<FloatType>();
555 
556  Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
557  Value lhsRealAbs = b.create<math::AbsFOp>(lhsReal);
558  Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
559  Value lhsImagAbs = b.create<math::AbsFOp>(lhsImag);
560  Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
561  Value rhsRealAbs = b.create<math::AbsFOp>(rhsReal);
562  Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
563  Value rhsImagAbs = b.create<math::AbsFOp>(rhsImag);
564 
565  Value lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
566  Value lhsRealTimesRhsRealAbs = b.create<math::AbsFOp>(lhsRealTimesRhsReal);
567  Value lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
568  Value lhsImagTimesRhsImagAbs = b.create<math::AbsFOp>(lhsImagTimesRhsImag);
569  Value real =
570  b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
571 
572  Value lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
573  Value lhsImagTimesRhsRealAbs = b.create<math::AbsFOp>(lhsImagTimesRhsReal);
574  Value lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
575  Value lhsRealTimesRhsImagAbs = b.create<math::AbsFOp>(lhsRealTimesRhsImag);
576  Value imag =
577  b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
578 
579  // Handle cases where the "naive" calculation results in NaN values.
580  Value realIsNan =
581  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real);
582  Value imagIsNan =
583  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, imag, imag);
584  Value isNan = b.create<arith::AndIOp>(realIsNan, imagIsNan);
585 
586  Value inf = b.create<arith::ConstantOp>(
587  elementType,
588  b.getFloatAttr(elementType,
589  APFloat::getInf(elementType.getFloatSemantics())));
590 
591  // Case 1. `lhsReal` or `lhsImag` are infinite.
592  Value lhsRealIsInf =
593  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsRealAbs, inf);
594  Value lhsImagIsInf =
595  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhsImagAbs, inf);
596  Value lhsIsInf = b.create<arith::OrIOp>(lhsRealIsInf, lhsImagIsInf);
597  Value rhsRealIsNan =
598  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsReal, rhsReal);
599  Value rhsImagIsNan =
600  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, rhsImag, rhsImag);
601  Value zero =
602  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
603  Value one = b.create<arith::ConstantOp>(elementType,
604  b.getFloatAttr(elementType, 1));
605  Value lhsRealIsInfFloat =
606  b.create<arith::SelectOp>(lhsRealIsInf, one, zero);
607  lhsReal = b.create<arith::SelectOp>(
608  lhsIsInf, b.create<math::CopySignOp>(lhsRealIsInfFloat, lhsReal),
609  lhsReal);
610  Value lhsImagIsInfFloat =
611  b.create<arith::SelectOp>(lhsImagIsInf, one, zero);
612  lhsImag = b.create<arith::SelectOp>(
613  lhsIsInf, b.create<math::CopySignOp>(lhsImagIsInfFloat, lhsImag),
614  lhsImag);
615  Value lhsIsInfAndRhsRealIsNan =
616  b.create<arith::AndIOp>(lhsIsInf, rhsRealIsNan);
617  rhsReal = b.create<arith::SelectOp>(
618  lhsIsInfAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
619  rhsReal);
620  Value lhsIsInfAndRhsImagIsNan =
621  b.create<arith::AndIOp>(lhsIsInf, rhsImagIsNan);
622  rhsImag = b.create<arith::SelectOp>(
623  lhsIsInfAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
624  rhsImag);
625 
626  // Case 2. `rhsReal` or `rhsImag` are infinite.
627  Value rhsRealIsInf =
628  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsRealAbs, inf);
629  Value rhsImagIsInf =
630  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, rhsImagAbs, inf);
631  Value rhsIsInf = b.create<arith::OrIOp>(rhsRealIsInf, rhsImagIsInf);
632  Value lhsRealIsNan =
633  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsReal, lhsReal);
634  Value lhsImagIsNan =
635  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, lhsImag, lhsImag);
636  Value rhsRealIsInfFloat =
637  b.create<arith::SelectOp>(rhsRealIsInf, one, zero);
638  rhsReal = b.create<arith::SelectOp>(
639  rhsIsInf, b.create<math::CopySignOp>(rhsRealIsInfFloat, rhsReal),
640  rhsReal);
641  Value rhsImagIsInfFloat =
642  b.create<arith::SelectOp>(rhsImagIsInf, one, zero);
643  rhsImag = b.create<arith::SelectOp>(
644  rhsIsInf, b.create<math::CopySignOp>(rhsImagIsInfFloat, rhsImag),
645  rhsImag);
646  Value rhsIsInfAndLhsRealIsNan =
647  b.create<arith::AndIOp>(rhsIsInf, lhsRealIsNan);
648  lhsReal = b.create<arith::SelectOp>(
649  rhsIsInfAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
650  lhsReal);
651  Value rhsIsInfAndLhsImagIsNan =
652  b.create<arith::AndIOp>(rhsIsInf, lhsImagIsNan);
653  lhsImag = b.create<arith::SelectOp>(
654  rhsIsInfAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
655  lhsImag);
656  Value recalc = b.create<arith::OrIOp>(lhsIsInf, rhsIsInf);
657 
658  // Case 3. One of the pairwise products of left hand side with right hand
659  // side is infinite.
660  Value lhsRealTimesRhsRealIsInf = b.create<arith::CmpFOp>(
661  arith::CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
662  Value lhsImagTimesRhsImagIsInf = b.create<arith::CmpFOp>(
663  arith::CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
664  Value isSpecialCase = b.create<arith::OrIOp>(lhsRealTimesRhsRealIsInf,
665  lhsImagTimesRhsImagIsInf);
666  Value lhsRealTimesRhsImagIsInf = b.create<arith::CmpFOp>(
667  arith::CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
668  isSpecialCase =
669  b.create<arith::OrIOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
670  Value lhsImagTimesRhsRealIsInf = b.create<arith::CmpFOp>(
671  arith::CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
672  isSpecialCase =
673  b.create<arith::OrIOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
674  Type i1Type = b.getI1Type();
675  Value notRecalc = b.create<arith::XOrIOp>(
676  recalc,
677  b.create<arith::ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
678  isSpecialCase = b.create<arith::AndIOp>(isSpecialCase, notRecalc);
679  Value isSpecialCaseAndLhsRealIsNan =
680  b.create<arith::AndIOp>(isSpecialCase, lhsRealIsNan);
681  lhsReal = b.create<arith::SelectOp>(
682  isSpecialCaseAndLhsRealIsNan, b.create<math::CopySignOp>(zero, lhsReal),
683  lhsReal);
684  Value isSpecialCaseAndLhsImagIsNan =
685  b.create<arith::AndIOp>(isSpecialCase, lhsImagIsNan);
686  lhsImag = b.create<arith::SelectOp>(
687  isSpecialCaseAndLhsImagIsNan, b.create<math::CopySignOp>(zero, lhsImag),
688  lhsImag);
689  Value isSpecialCaseAndRhsRealIsNan =
690  b.create<arith::AndIOp>(isSpecialCase, rhsRealIsNan);
691  rhsReal = b.create<arith::SelectOp>(
692  isSpecialCaseAndRhsRealIsNan, b.create<math::CopySignOp>(zero, rhsReal),
693  rhsReal);
694  Value isSpecialCaseAndRhsImagIsNan =
695  b.create<arith::AndIOp>(isSpecialCase, rhsImagIsNan);
696  rhsImag = b.create<arith::SelectOp>(
697  isSpecialCaseAndRhsImagIsNan, b.create<math::CopySignOp>(zero, rhsImag),
698  rhsImag);
699  recalc = b.create<arith::OrIOp>(recalc, isSpecialCase);
700  recalc = b.create<arith::AndIOp>(isNan, recalc);
701 
702  // Recalculate real part.
703  lhsRealTimesRhsReal = b.create<arith::MulFOp>(lhsReal, rhsReal);
704  lhsImagTimesRhsImag = b.create<arith::MulFOp>(lhsImag, rhsImag);
705  Value newReal =
706  b.create<arith::SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
707  real = b.create<arith::SelectOp>(
708  recalc, b.create<arith::MulFOp>(inf, newReal), real);
709 
710  // Recalculate imag part.
711  lhsImagTimesRhsReal = b.create<arith::MulFOp>(lhsImag, rhsReal);
712  lhsRealTimesRhsImag = b.create<arith::MulFOp>(lhsReal, rhsImag);
713  Value newImag =
714  b.create<arith::AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
715  imag = b.create<arith::SelectOp>(
716  recalc, b.create<arith::MulFOp>(inf, newImag), imag);
717 
718  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
719  return success();
720  }
721 };
722 
723 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
725 
727  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
728  ConversionPatternRewriter &rewriter) const override {
729  auto loc = op.getLoc();
730  auto type = adaptor.getComplex().getType().cast<ComplexType>();
731  auto elementType = type.getElementType().cast<FloatType>();
732 
733  Value real =
734  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
735  Value imag =
736  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
737  Value negReal = rewriter.create<arith::NegFOp>(loc, real);
738  Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
739  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
740  return success();
741  }
742 };
743 
744 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
745  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
746 
747  std::pair<Value, Value>
748  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
749  Value cos, ConversionPatternRewriter &rewriter) const override {
750  // Complex sine is defined as;
751  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
752  // Plugging in:
753  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
754  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
755  // and defining t := exp(y)
756  // We get:
757  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
758  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
759  Value sum = rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp);
760  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin);
761  Value diff = rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp);
762  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos);
763  return {resultReal, resultImag};
764  }
765 };
766 
767 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
768 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
770 
772  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
773  ConversionPatternRewriter &rewriter) const override {
774  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
775 
776  auto type = op.getType().cast<ComplexType>();
777  Type elementType = type.getElementType();
778  Value arg = adaptor.getComplex();
779 
780  Value zero =
781  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
782 
783  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
784  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
785 
786  Value absLhs = b.create<math::AbsFOp>(real);
787  Value absArg = b.create<complex::AbsOp>(elementType, arg);
788  Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
789 
790  Value half = b.create<arith::ConstantOp>(elementType,
791  b.getFloatAttr(elementType, 0.5));
792  Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
793  Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs);
794 
795  Value realIsNegative =
796  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
797  Value imagIsNegative =
798  b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
799 
800  Value resultReal = sqrtAddAbs;
801 
802  Value imagDivTwoResultReal = b.create<arith::DivFOp>(
803  imag, b.create<arith::AddFOp>(resultReal, resultReal));
804 
805  Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
806 
807  Value resultImag = b.create<arith::SelectOp>(
808  realIsNegative,
809  b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
810  resultReal),
811  imagDivTwoResultReal);
812 
813  resultReal = b.create<arith::SelectOp>(
814  realIsNegative,
815  b.create<arith::DivFOp>(
816  imag, b.create<arith::AddFOp>(resultImag, resultImag)),
817  resultReal);
818 
819  Value realIsZero =
820  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
821  Value imagIsZero =
822  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
823  Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
824 
825  resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
826  resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
827 
828  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
829  resultImag);
830  return success();
831  }
832 };
833 
834 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
836 
838  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
839  ConversionPatternRewriter &rewriter) const override {
840  auto type = adaptor.getComplex().getType().cast<ComplexType>();
841  auto elementType = type.getElementType().cast<FloatType>();
842  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
843 
844  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
845  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
846  Value zero =
847  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
848  Value realIsZero =
849  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
850  Value imagIsZero =
851  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
852  Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
853  auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex());
854  Value realSign = b.create<arith::DivFOp>(real, abs);
855  Value imagSign = b.create<arith::DivFOp>(imag, abs);
856  Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
857  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
858  adaptor.getComplex(), sign);
859  return success();
860  }
861 };
862 
863 struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
865 
867  matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
868  ConversionPatternRewriter &rewriter) const override {
869  auto loc = op.getLoc();
870  Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
871  Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
872  rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
873  return success();
874  }
875 };
876 
877 struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
879 
881  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
882  ConversionPatternRewriter &rewriter) const override {
883  auto loc = op.getLoc();
884  auto type = adaptor.getComplex().getType().cast<ComplexType>();
885  auto elementType = type.getElementType().cast<FloatType>();
886 
887  // The hyperbolic tangent for complex number can be calculated as follows.
888  // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + tanh(x) * tan(y))
889  // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
890  Value real =
891  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
892  Value imag =
893  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
894  Value tanhA = rewriter.create<math::TanhOp>(loc, real);
895  Value cosB = rewriter.create<math::CosOp>(loc, imag);
896  Value sinB = rewriter.create<math::SinOp>(loc, imag);
897  Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
898  Value numerator =
899  rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
900  Value one = rewriter.create<arith::ConstantOp>(
901  loc, elementType, rewriter.getFloatAttr(elementType, 1));
902  Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
903  Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
904  rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
905  return success();
906  }
907 };
908 
909 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
911 
913  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
914  ConversionPatternRewriter &rewriter) const override {
915  auto loc = op.getLoc();
916  auto type = adaptor.getComplex().getType().cast<ComplexType>();
917  auto elementType = type.getElementType().cast<FloatType>();
918  Value real =
919  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
920  Value imag =
921  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
922  Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
923 
924  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
925 
926  return success();
927  }
928 };
929 
930 /// Coverts x^y = (a+bi)^(c+di) to
931 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
932 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
933 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
934  ComplexType type, Value a, Value b, Value c,
935  Value d) {
936  auto elementType = type.getElementType().cast<FloatType>();
937 
938  // Compute (a*a+b*b)^(0.5c).
939  Value aaPbb = builder.create<arith::AddFOp>(
940  builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
941  Value half = builder.create<arith::ConstantOp>(
942  elementType, builder.getFloatAttr(elementType, 0.5));
943  Value halfC = builder.create<arith::MulFOp>(half, c);
944  Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
945 
946  // Compute exp(-d*atan2(b,a)).
947  Value negD = builder.create<arith::NegFOp>(d);
948  Value argX = builder.create<math::Atan2Op>(b, a);
949  Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
950  Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
951 
952  // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
953  Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
954 
955  // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
956  Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
957  Value halfD = builder.create<arith::MulFOp>(half, d);
958  Value q = builder.create<arith::AddFOp>(
959  builder.create<arith::MulFOp>(c, argX),
960  builder.create<arith::MulFOp>(halfD, lnAaPbb));
961 
962  Value cosQ = builder.create<math::CosOp>(q);
963  Value sinQ = builder.create<math::SinOp>(q);
964  Value zero = builder.create<arith::ConstantOp>(
965  elementType, builder.getFloatAttr(elementType, 0));
966  Value one = builder.create<arith::ConstantOp>(
967  elementType, builder.getFloatAttr(elementType, 1));
968 
969  Value xEqZero =
970  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
971  Value yGeZero = builder.create<arith::AndIOp>(
972  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
973  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
974  Value cEqZero =
975  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
976  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
977  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
978  Value complexOther = builder.create<complex::CreateOp>(
979  type, builder.create<arith::MulFOp>(coeff, cosQ),
980  builder.create<arith::MulFOp>(coeff, sinQ));
981 
982  // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
983  // Branch Cuts for Complex Elementary Functions or Much Ado About
984  // Nothing's Sign Bit, W. Kahan, Section 10.
985  return builder.create<arith::SelectOp>(
986  builder.create<arith::AndIOp>(xEqZero, yGeZero),
987  builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
988  complexOther);
989 }
990 
991 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
993 
995  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
996  ConversionPatternRewriter &rewriter) const override {
997  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
998  auto type = adaptor.getLhs().getType().cast<ComplexType>();
999  auto elementType = type.getElementType().cast<FloatType>();
1000 
1001  Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
1002  Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
1003  Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
1004  Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
1005 
1006  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
1007  return success();
1008  }
1009 };
1010 
1011 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1013 
1015  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1016  ConversionPatternRewriter &rewriter) const override {
1017  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1018  auto type = adaptor.getComplex().getType().cast<ComplexType>();
1019  auto elementType = type.getElementType().cast<FloatType>();
1020 
1021  Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
1022  Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
1023  Value c = builder.create<arith::ConstantOp>(
1024  elementType, builder.getFloatAttr(elementType, -0.5));
1025  Value d = builder.create<arith::ConstantOp>(
1026  elementType, builder.getFloatAttr(elementType, 0));
1027 
1028  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
1029  return success();
1030  }
1031 };
1032 
1033 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1035 
1037  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1038  ConversionPatternRewriter &rewriter) const override {
1039  auto loc = op.getLoc();
1040  auto type = op.getType();
1041 
1042  Value real =
1043  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1044  Value imag =
1045  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1046 
1047  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real);
1048 
1049  return success();
1050  }
1051 };
1052 
1053 } // namespace
1054 
1056  RewritePatternSet &patterns) {
1057  // clang-format off
1058  patterns.add<
1059  AbsOpConversion,
1060  AngleOpConversion,
1061  Atan2OpConversion,
1062  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1063  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1064  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1065  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1066  ConjOpConversion,
1067  CosOpConversion,
1068  DivOpConversion,
1069  ExpOpConversion,
1070  Expm1OpConversion,
1071  Log1pOpConversion,
1072  LogOpConversion,
1073  MulOpConversion,
1074  NegOpConversion,
1075  SignOpConversion,
1076  SinOpConversion,
1077  SqrtOpConversion,
1078  TanOpConversion,
1079  TanhOpConversion,
1080  PowOpConversion,
1081  RsqrtOpConversion
1082  >(patterns.getContext());
1083  // clang-format on
1084 }
1085 
1086 namespace {
1087 struct ConvertComplexToStandardPass
1088  : public impl::ConvertComplexToStandardBase<ConvertComplexToStandardPass> {
1089  void runOnOperation() override;
1090 };
1091 
1092 void ConvertComplexToStandardPass::runOnOperation() {
1093  // Convert to the Standard dialect using the converter defined above.
1094  RewritePatternSet patterns(&getContext());
1096 
1097  ConversionTarget target(getContext());
1098  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1099  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1100  if (failed(
1101  applyPartialConversion(getOperation(), target, std::move(patterns))))
1102  signalPassFailure();
1103 }
1104 } // namespace
1105 
1107  return std::make_unique<ConvertComplexToStandardPass>();
1108 }
Include the generated interface declarations.
void addLegalOp(OperationName op)
Register the given operations as legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:302
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:694
typename SourceOp::Adaptor OpAdaptor
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:87
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:231
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:418
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
std::unique_ptr< Pass > createConvertComplexToStandardPass()
Create a pass to convert Complex operations to the Standard dialect.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to Standard.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class implements a pattern rewriter for use with ConversionPatterns.
U cast() const
Definition: Value.h:108
This class describes a specific conversion target.
static bool isZero(OpFoldResult v)
Definition: Tiling.cpp:43
MLIRContext * getContext() const