MLIR  21.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <memory>
#include <type_traits>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
// Selects which function of the complex magnitude computeAbs() should emit.
enum class AbsFn {
  abs,   // the magnitude itself
  sqrt,  // square root of the magnitude
  rsqrt, // reciprocal square root of the magnitude
};
32 
33 // Returns the absolute value, its square root or its reciprocal square root.
34 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
35  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
36  Value one = b.create<arith::ConstantOp>(real.getType(),
37  b.getFloatAttr(real.getType(), 1.0));
38 
39  Value absReal = b.create<math::AbsFOp>(real, fmf);
40  Value absImag = b.create<math::AbsFOp>(imag, fmf);
41 
42  Value max = b.create<arith::MaximumFOp>(absReal, absImag, fmf);
43  Value min = b.create<arith::MinimumFOp>(absReal, absImag, fmf);
44 
45  // The lowering below requires NaNs and infinities to work correctly.
46  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
47  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
48  Value ratio = b.create<arith::DivFOp>(min, max, fmfWithNaNInf);
49  Value ratioSq = b.create<arith::MulFOp>(ratio, ratio, fmfWithNaNInf);
50  Value ratioSqPlusOne = b.create<arith::AddFOp>(ratioSq, one, fmfWithNaNInf);
51  Value result;
52 
53  if (fn == AbsFn::rsqrt) {
54  ratioSqPlusOne = b.create<math::RsqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
55  min = b.create<math::RsqrtOp>(min, fmfWithNaNInf);
56  max = b.create<math::RsqrtOp>(max, fmfWithNaNInf);
57  }
58 
59  if (fn == AbsFn::sqrt) {
60  Value quarter = b.create<arith::ConstantOp>(
61  real.getType(), b.getFloatAttr(real.getType(), 0.25));
62  // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
63  Value sqrt = b.create<math::SqrtOp>(max, fmfWithNaNInf);
64  Value p025 = b.create<math::PowFOp>(ratioSqPlusOne, quarter, fmfWithNaNInf);
65  result = b.create<arith::MulFOp>(sqrt, p025, fmfWithNaNInf);
66  } else {
67  Value sqrt = b.create<math::SqrtOp>(ratioSqPlusOne, fmfWithNaNInf);
68  result = b.create<arith::MulFOp>(max, sqrt, fmfWithNaNInf);
69  }
70 
71  Value isNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, result,
72  result, fmfWithNaNInf);
73  return b.create<arith::SelectOp>(isNaN, min, result);
74 }
75 
76 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
78 
79  LogicalResult
80  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
81  ConversionPatternRewriter &rewriter) const override {
82  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
83 
84  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
85 
86  Value real = b.create<complex::ReOp>(adaptor.getComplex());
87  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
88  rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
89 
90  return success();
91  }
92 };
93 
94 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
95 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
97 
98  LogicalResult
99  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override {
101  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
102 
103  auto type = cast<ComplexType>(op.getType());
104  Type elementType = type.getElementType();
105  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
106 
107  Value lhs = adaptor.getLhs();
108  Value rhs = adaptor.getRhs();
109 
110  Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs, fmf);
111  Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs, fmf);
112  Value rhsSquaredPlusLhsSquared =
113  b.create<complex::AddOp>(type, rhsSquared, lhsSquared, fmf);
114  Value sqrtOfRhsSquaredPlusLhsSquared =
115  b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared, fmf);
116 
117  Value zero =
118  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
119  Value one = b.create<arith::ConstantOp>(elementType,
120  b.getFloatAttr(elementType, 1));
121  Value i = b.create<complex::CreateOp>(type, zero, one);
122  Value iTimesLhs = b.create<complex::MulOp>(i, lhs, fmf);
123  Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs, fmf);
124 
125  Value divResult = b.create<complex::DivOp>(
126  rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
127  Value logResult = b.create<complex::LogOp>(divResult, fmf);
128 
129  Value negativeOne = b.create<arith::ConstantOp>(
130  elementType, b.getFloatAttr(elementType, -1));
131  Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
132 
133  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
134  return success();
135  }
136 };
137 
138 template <typename ComparisonOp, arith::CmpFPredicate p>
139 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
141  using ResultCombiner =
142  std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
143  arith::AndIOp, arith::OrIOp>;
144 
145  LogicalResult
146  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  auto loc = op.getLoc();
149  auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
150 
151  Value realLhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getLhs());
152  Value imagLhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getLhs());
153  Value realRhs = rewriter.create<complex::ReOp>(loc, type, adaptor.getRhs());
154  Value imagRhs = rewriter.create<complex::ImOp>(loc, type, adaptor.getRhs());
155  Value realComparison =
156  rewriter.create<arith::CmpFOp>(loc, p, realLhs, realRhs);
157  Value imagComparison =
158  rewriter.create<arith::CmpFOp>(loc, p, imagLhs, imagRhs);
159 
160  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
161  imagComparison);
162  return success();
163  }
164 };
165 
166 // Default conversion which applies the BinaryStandardOp separately on the real
167 // and imaginary parts. Can for example be used for complex::AddOp and
168 // complex::SubOp.
169 template <typename BinaryComplexOp, typename BinaryStandardOp>
170 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
172 
173  LogicalResult
174  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
175  ConversionPatternRewriter &rewriter) const override {
176  auto type = cast<ComplexType>(adaptor.getLhs().getType());
177  auto elementType = cast<FloatType>(type.getElementType());
178  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
179  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
180 
181  Value realLhs = b.create<complex::ReOp>(elementType, adaptor.getLhs());
182  Value realRhs = b.create<complex::ReOp>(elementType, adaptor.getRhs());
183  Value resultReal = b.create<BinaryStandardOp>(elementType, realLhs, realRhs,
184  fmf.getValue());
185  Value imagLhs = b.create<complex::ImOp>(elementType, adaptor.getLhs());
186  Value imagRhs = b.create<complex::ImOp>(elementType, adaptor.getRhs());
187  Value resultImag = b.create<BinaryStandardOp>(elementType, imagLhs, imagRhs,
188  fmf.getValue());
189  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
190  resultImag);
191  return success();
192  }
193 };
194 
195 template <typename TrigonometricOp>
196 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
197  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
198 
200 
201  LogicalResult
202  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
203  ConversionPatternRewriter &rewriter) const override {
204  auto loc = op.getLoc();
205  auto type = cast<ComplexType>(adaptor.getComplex().getType());
206  auto elementType = cast<FloatType>(type.getElementType());
207  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
208 
209  Value real =
210  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
211  Value imag =
212  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
213 
214  // Trigonometric ops use a set of common building blocks to convert to real
215  // ops. Here we create these building blocks and call into an op-specific
216  // implementation in the subclass to combine them.
217  Value half = rewriter.create<arith::ConstantOp>(
218  loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
219  Value exp = rewriter.create<math::ExpOp>(loc, imag, fmf);
220  Value scaledExp = rewriter.create<arith::MulFOp>(loc, half, exp, fmf);
221  Value reciprocalExp = rewriter.create<arith::DivFOp>(loc, half, exp, fmf);
222  Value sin = rewriter.create<math::SinOp>(loc, real, fmf);
223  Value cos = rewriter.create<math::CosOp>(loc, real, fmf);
224 
225  auto resultPair =
226  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
227 
228  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
229  resultPair.second);
230  return success();
231  }
232 
233  virtual std::pair<Value, Value>
234  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
235  Value cos, ConversionPatternRewriter &rewriter,
236  arith::FastMathFlagsAttr fmf) const = 0;
237 };
238 
239 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
240  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
241 
242  std::pair<Value, Value> combine(Location loc, Value scaledExp,
243  Value reciprocalExp, Value sin, Value cos,
244  ConversionPatternRewriter &rewriter,
245  arith::FastMathFlagsAttr fmf) const override {
246  // Complex cosine is defined as;
247  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
248  // Plugging in:
249  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
250  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
251  // and defining t := exp(y)
252  // We get:
253  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
254  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
255  Value sum =
256  rewriter.create<arith::AddFOp>(loc, reciprocalExp, scaledExp, fmf);
257  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, cos, fmf);
258  Value diff =
259  rewriter.create<arith::SubFOp>(loc, reciprocalExp, scaledExp, fmf);
260  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, sin, fmf);
261  return {resultReal, resultImag};
262  }
263 };
264 
265 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
266  DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
267  : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
268 
270 
271  LogicalResult
272  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
273  ConversionPatternRewriter &rewriter) const override {
274  auto loc = op.getLoc();
275  auto type = cast<ComplexType>(adaptor.getLhs().getType());
276  auto elementType = cast<FloatType>(type.getElementType());
277  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
278 
279  Value lhsReal =
280  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getLhs());
281  Value lhsImag =
282  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getLhs());
283  Value rhsReal =
284  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getRhs());
285  Value rhsImag =
286  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getRhs());
287 
288  Value resultReal, resultImag;
289 
290  if (complexRange == complex::ComplexRangeFlags::basic ||
291  complexRange == complex::ComplexRangeFlags::none) {
293  rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
294  &resultImag);
295  } else if (complexRange == complex::ComplexRangeFlags::improved) {
297  rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
298  &resultImag);
299  }
300 
301  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
302  resultImag);
303 
304  return success();
305  }
306 
307 private:
308  complex::ComplexRangeFlags complexRange;
309 };
310 
311 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
313 
314  LogicalResult
315  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
316  ConversionPatternRewriter &rewriter) const override {
317  auto loc = op.getLoc();
318  auto type = cast<ComplexType>(adaptor.getComplex().getType());
319  auto elementType = cast<FloatType>(type.getElementType());
320  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
321 
322  Value real =
323  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
324  Value imag =
325  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
326  Value expReal = rewriter.create<math::ExpOp>(loc, real, fmf.getValue());
327  Value cosImag = rewriter.create<math::CosOp>(loc, imag, fmf.getValue());
328  Value resultReal =
329  rewriter.create<arith::MulFOp>(loc, expReal, cosImag, fmf.getValue());
330  Value sinImag = rewriter.create<math::SinOp>(loc, imag, fmf.getValue());
331  Value resultImag =
332  rewriter.create<arith::MulFOp>(loc, expReal, sinImag, fmf.getValue());
333 
334  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
335  resultImag);
336  return success();
337  }
338 };
339 
340 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
341  ArrayRef<double> coefficients,
342  arith::FastMathFlagsAttr fmf) {
343  auto argType = mlir::cast<FloatType>(arg.getType());
344  Value poly =
345  b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[0]));
346  for (unsigned i = 1; i < coefficients.size(); ++i) {
347  poly = b.create<math::FmaOp>(
348  poly, arg,
349  b.create<arith::ConstantOp>(b.getFloatAttr(argType, coefficients[i])),
350  fmf);
351  }
352  return poly;
353 }
354 
355 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
357 
358  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
359  // [handle inaccuracies when a and/or b are small]
360  // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
361  // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
362  LogicalResult
363  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
364  ConversionPatternRewriter &rewriter) const override {
365  auto type = op.getType();
366  auto elemType = mlir::cast<FloatType>(type.getElementType());
367 
368  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
369  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
370  Value real = b.create<complex::ReOp>(adaptor.getComplex());
371  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
372 
373  Value zero = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 0.0));
374  Value one = b.create<arith::ConstantOp>(b.getFloatAttr(elemType, 1.0));
375 
376  Value expm1Real = b.create<math::ExpM1Op>(real, fmf);
377  Value expReal = b.create<arith::AddFOp>(expm1Real, one, fmf);
378 
379  Value sinImag = b.create<math::SinOp>(imag, fmf);
380  Value cosm1Imag = emitCosm1(imag, fmf, b);
381  Value cosImag = b.create<arith::AddFOp>(cosm1Imag, one, fmf);
382 
383  Value realResult = b.create<arith::AddFOp>(
384  b.create<arith::MulFOp>(expm1Real, cosImag, fmf), cosm1Imag, fmf);
385 
386  Value imagIsZero = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag,
387  zero, fmf.getValue());
388  Value imagResult = b.create<arith::SelectOp>(
389  imagIsZero, zero, b.create<arith::MulFOp>(expReal, sinImag, fmf));
390 
391  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
392  imagResult);
393  return success();
394  }
395 
396 private:
397  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
398  ImplicitLocOpBuilder &b) const {
399  auto argType = mlir::cast<FloatType>(arg.getType());
400  auto negHalf = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -0.5));
401  auto negOne = b.create<arith::ConstantOp>(b.getFloatAttr(argType, -1.0));
402 
403  // Algorithm copied from cephes cosm1.
404  SmallVector<double, 7> kCoeffs{
405  4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
406  2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
407  2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
408  4.1666666666666666609054E-2,
409  };
410  Value cos = b.create<math::CosOp>(arg, fmf);
411  Value forLargeArg = b.create<arith::AddFOp>(cos, negOne, fmf);
412 
413  Value argPow2 = b.create<arith::MulFOp>(arg, arg, fmf);
414  Value argPow4 = b.create<arith::MulFOp>(argPow2, argPow2, fmf);
415  Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
416 
417  auto forSmallArg =
418  b.create<arith::AddFOp>(b.create<arith::MulFOp>(argPow4, poly, fmf),
419  b.create<arith::MulFOp>(negHalf, argPow2, fmf));
420 
421  // (pi/4)^2 is approximately 0.61685
422  Value piOver4Pow2 =
423  b.create<arith::ConstantOp>(b.getFloatAttr(argType, 0.61685));
424  Value cond = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, argPow2,
425  piOver4Pow2, fmf.getValue());
426  return b.create<arith::SelectOp>(cond, forLargeArg, forSmallArg);
427  }
428 };
429 
430 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
432 
433  LogicalResult
434  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
435  ConversionPatternRewriter &rewriter) const override {
436  auto type = cast<ComplexType>(adaptor.getComplex().getType());
437  auto elementType = cast<FloatType>(type.getElementType());
438  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
439  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440 
441  Value abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(),
442  fmf.getValue());
443  Value resultReal = b.create<math::LogOp>(elementType, abs, fmf.getValue());
444  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
445  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
446  Value resultImag =
447  b.create<math::Atan2Op>(elementType, imag, real, fmf.getValue());
448  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
449  resultImag);
450  return success();
451  }
452 };
453 
454 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
456 
457  LogicalResult
458  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
459  ConversionPatternRewriter &rewriter) const override {
460  auto type = cast<ComplexType>(adaptor.getComplex().getType());
461  auto elementType = cast<FloatType>(type.getElementType());
462  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
463  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
464 
465  Value real = b.create<complex::ReOp>(adaptor.getComplex());
466  Value imag = b.create<complex::ImOp>(adaptor.getComplex());
467 
468  Value half = b.create<arith::ConstantOp>(elementType,
469  b.getFloatAttr(elementType, 0.5));
470  Value one = b.create<arith::ConstantOp>(elementType,
471  b.getFloatAttr(elementType, 1));
472  Value realPlusOne = b.create<arith::AddFOp>(real, one, fmf);
473  Value absRealPlusOne = b.create<math::AbsFOp>(realPlusOne, fmf);
474  Value absImag = b.create<math::AbsFOp>(imag, fmf);
475 
476  Value maxAbs = b.create<arith::MaximumFOp>(absRealPlusOne, absImag, fmf);
477  Value minAbs = b.create<arith::MinimumFOp>(absRealPlusOne, absImag, fmf);
478 
479  Value useReal = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGT,
480  realPlusOne, absImag, fmf);
481  Value maxMinusOne = b.create<arith::SubFOp>(maxAbs, one, fmf);
482  Value maxAbsOfRealPlusOneAndImagMinusOne =
483  b.create<arith::SelectOp>(useReal, real, maxMinusOne);
484  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
485  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
486  Value minMaxRatio = b.create<arith::DivFOp>(minAbs, maxAbs, fmfWithNaNInf);
487  Value logOfMaxAbsOfRealPlusOneAndImag =
488  b.create<math::Log1pOp>(maxAbsOfRealPlusOneAndImagMinusOne, fmf);
489  Value logOfSqrtPart = b.create<math::Log1pOp>(
490  b.create<arith::MulFOp>(minMaxRatio, minMaxRatio, fmfWithNaNInf),
491  fmfWithNaNInf);
492  Value r = b.create<arith::AddFOp>(
493  b.create<arith::MulFOp>(half, logOfSqrtPart, fmfWithNaNInf),
494  logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
495  Value resultReal = b.create<arith::SelectOp>(
496  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, r, r, fmfWithNaNInf),
497  minAbs, r);
498  Value resultImag = b.create<math::Atan2Op>(imag, realPlusOne, fmf);
499  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
500  resultImag);
501  return success();
502  }
503 };
504 
505 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
507 
508  LogicalResult
509  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
510  ConversionPatternRewriter &rewriter) const override {
511  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
512  auto type = cast<ComplexType>(adaptor.getLhs().getType());
513  auto elementType = cast<FloatType>(type.getElementType());
514  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
515  auto fmfValue = fmf.getValue();
516  Value lhsReal = b.create<complex::ReOp>(elementType, adaptor.getLhs());
517  Value lhsImag = b.create<complex::ImOp>(elementType, adaptor.getLhs());
518  Value rhsReal = b.create<complex::ReOp>(elementType, adaptor.getRhs());
519  Value rhsImag = b.create<complex::ImOp>(elementType, adaptor.getRhs());
520  Value lhsRealTimesRhsReal =
521  b.create<arith::MulFOp>(lhsReal, rhsReal, fmfValue);
522  Value lhsImagTimesRhsImag =
523  b.create<arith::MulFOp>(lhsImag, rhsImag, fmfValue);
524  Value real = b.create<arith::SubFOp>(lhsRealTimesRhsReal,
525  lhsImagTimesRhsImag, fmfValue);
526  Value lhsImagTimesRhsReal =
527  b.create<arith::MulFOp>(lhsImag, rhsReal, fmfValue);
528  Value lhsRealTimesRhsImag =
529  b.create<arith::MulFOp>(lhsReal, rhsImag, fmfValue);
530  Value imag = b.create<arith::AddFOp>(lhsImagTimesRhsReal,
531  lhsRealTimesRhsImag, fmfValue);
532  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
533  return success();
534  }
535 };
536 
537 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
539 
540  LogicalResult
541  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
542  ConversionPatternRewriter &rewriter) const override {
543  auto loc = op.getLoc();
544  auto type = cast<ComplexType>(adaptor.getComplex().getType());
545  auto elementType = cast<FloatType>(type.getElementType());
546 
547  Value real =
548  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
549  Value imag =
550  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
551  Value negReal = rewriter.create<arith::NegFOp>(loc, real);
552  Value negImag = rewriter.create<arith::NegFOp>(loc, imag);
553  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
554  return success();
555  }
556 };
557 
558 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
559  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
560 
561  std::pair<Value, Value> combine(Location loc, Value scaledExp,
562  Value reciprocalExp, Value sin, Value cos,
563  ConversionPatternRewriter &rewriter,
564  arith::FastMathFlagsAttr fmf) const override {
565  // Complex sine is defined as;
566  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
567  // Plugging in:
568  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
569  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
570  // and defining t := exp(y)
571  // We get:
572  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
573  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
574  Value sum =
575  rewriter.create<arith::AddFOp>(loc, scaledExp, reciprocalExp, fmf);
576  Value resultReal = rewriter.create<arith::MulFOp>(loc, sum, sin, fmf);
577  Value diff =
578  rewriter.create<arith::SubFOp>(loc, scaledExp, reciprocalExp, fmf);
579  Value resultImag = rewriter.create<arith::MulFOp>(loc, diff, cos, fmf);
580  return {resultReal, resultImag};
581  }
582 };
583 
584 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
585 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
587 
588  LogicalResult
589  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
590  ConversionPatternRewriter &rewriter) const override {
591  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
592 
593  auto type = cast<ComplexType>(op.getType());
594  auto elementType = cast<FloatType>(type.getElementType());
595  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
596 
597  auto cst = [&](APFloat v) {
598  return b.create<arith::ConstantOp>(elementType,
599  b.getFloatAttr(elementType, v));
600  };
601  const auto &floatSemantics = elementType.getFloatSemantics();
602  Value zero = cst(APFloat::getZero(floatSemantics));
603  Value half = b.create<arith::ConstantOp>(elementType,
604  b.getFloatAttr(elementType, 0.5));
605 
606  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
607  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
608  Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
609  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
610  Value sqrtArg = b.create<arith::MulFOp>(argArg, half, fmf);
611  Value cos = b.create<math::CosOp>(sqrtArg, fmf);
612  Value sin = b.create<math::SinOp>(sqrtArg, fmf);
613  // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
614  // 0 * inf.
615  Value sinIsZero =
616  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, sin, zero, fmf);
617 
618  Value resultReal = b.create<arith::MulFOp>(absSqrt, cos, fmf);
619  Value resultImag = b.create<arith::SelectOp>(
620  sinIsZero, zero, b.create<arith::MulFOp>(absSqrt, sin, fmf));
621  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
622  arith::FastMathFlags::ninf)) {
623  Value inf = cst(APFloat::getInf(floatSemantics));
624  Value negInf = cst(APFloat::getInf(floatSemantics, true));
625  Value nan = cst(APFloat::getNaN(floatSemantics));
626  Value absImag = b.create<math::AbsFOp>(elementType, imag, fmf);
627 
628  Value absImagIsInf =
629  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
630  Value absImagIsNotInf =
631  b.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, absImag, inf, fmf);
632  Value realIsInf =
633  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, inf, fmf);
634  Value realIsNegInf =
635  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, negInf, fmf);
636 
637  resultReal = b.create<arith::SelectOp>(
638  b.create<arith::AndIOp>(realIsNegInf, absImagIsNotInf), zero,
639  resultReal);
640  resultReal = b.create<arith::SelectOp>(
641  b.create<arith::OrIOp>(absImagIsInf, realIsInf), inf, resultReal);
642 
643  Value imagSignInf = b.create<math::CopySignOp>(inf, imag, fmf);
644  resultImag = b.create<arith::SelectOp>(
645  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, absSqrt, absSqrt),
646  nan, resultImag);
647  resultImag = b.create<arith::SelectOp>(
648  b.create<arith::OrIOp>(absImagIsInf, realIsNegInf), imagSignInf,
649  resultImag);
650  }
651 
652  Value resultIsZero =
653  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
654  resultReal = b.create<arith::SelectOp>(resultIsZero, zero, resultReal);
655  resultImag = b.create<arith::SelectOp>(resultIsZero, zero, resultImag);
656 
657  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
658  resultImag);
659  return success();
660  }
661 };
662 
663 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
665 
666  LogicalResult
667  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
668  ConversionPatternRewriter &rewriter) const override {
669  auto type = cast<ComplexType>(adaptor.getComplex().getType());
670  auto elementType = cast<FloatType>(type.getElementType());
671  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
672  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
673 
674  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
675  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
676  Value zero =
677  b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
678  Value realIsZero =
679  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
680  Value imagIsZero =
681  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
682  Value isZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
683  auto abs = b.create<complex::AbsOp>(elementType, adaptor.getComplex(), fmf);
684  Value realSign = b.create<arith::DivFOp>(real, abs, fmf);
685  Value imagSign = b.create<arith::DivFOp>(imag, abs, fmf);
686  Value sign = b.create<complex::CreateOp>(type, realSign, imagSign);
687  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
688  adaptor.getComplex(), sign);
689  return success();
690  }
691 };
692 
693 template <typename Op>
694 struct TanTanhOpConversion : public OpConversionPattern<Op> {
696 
697  LogicalResult
698  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
699  ConversionPatternRewriter &rewriter) const override {
700  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
701  auto loc = op.getLoc();
702  auto type = cast<ComplexType>(adaptor.getComplex().getType());
703  auto elementType = cast<FloatType>(type.getElementType());
704  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
705  const auto &floatSemantics = elementType.getFloatSemantics();
706 
707  Value real =
708  b.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
709  Value imag =
710  b.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
711  Value negOne = b.create<arith::ConstantOp>(
712  elementType, b.getFloatAttr(elementType, -1.0));
713 
714  if constexpr (std::is_same_v<Op, complex::TanOp>) {
715  // tan(x+yi) = -i*tanh(-y + xi)
716  std::swap(real, imag);
717  real = b.create<arith::MulFOp>(real, negOne, fmf);
718  }
719 
720  auto cst = [&](APFloat v) {
721  return b.create<arith::ConstantOp>(elementType,
722  b.getFloatAttr(elementType, v));
723  };
724  Value inf = cst(APFloat::getInf(floatSemantics));
725  Value four = b.create<arith::ConstantOp>(elementType,
726  b.getFloatAttr(elementType, 4.0));
727  Value twoReal = b.create<arith::AddFOp>(real, real, fmf);
728  Value negTwoReal = b.create<arith::MulFOp>(negOne, twoReal, fmf);
729 
730  Value expTwoRealMinusOne = b.create<math::ExpM1Op>(twoReal, fmf);
731  Value expNegTwoRealMinusOne = b.create<math::ExpM1Op>(negTwoReal, fmf);
732  Value realNum =
733  b.create<arith::SubFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
734 
735  Value cosImag = b.create<math::CosOp>(imag, fmf);
736  Value cosImagSq = b.create<arith::MulFOp>(cosImag, cosImag, fmf);
737  Value twoCosTwoImagPlusOne = b.create<arith::MulFOp>(cosImagSq, four, fmf);
738  Value sinImag = b.create<math::SinOp>(imag, fmf);
739 
740  Value imagNum = b.create<arith::MulFOp>(
741  four, b.create<arith::MulFOp>(cosImag, sinImag, fmf), fmf);
742 
743  Value expSumMinusTwo =
744  b.create<arith::AddFOp>(expTwoRealMinusOne, expNegTwoRealMinusOne, fmf);
745  Value denom =
746  b.create<arith::AddFOp>(expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
747 
748  Value isInf = b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
749  expSumMinusTwo, inf, fmf);
750  Value realLimit = b.create<math::CopySignOp>(negOne, real, fmf);
751 
752  Value resultReal = b.create<arith::SelectOp>(
753  isInf, realLimit, b.create<arith::DivFOp>(realNum, denom, fmf));
754  Value resultImag = b.create<arith::DivFOp>(imagNum, denom, fmf);
755 
756  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
757  arith::FastMathFlags::ninf)) {
758  Value absReal = b.create<math::AbsFOp>(real, fmf);
759  Value zero = b.create<arith::ConstantOp>(
760  elementType, b.getFloatAttr(elementType, 0.0));
761  Value nan = cst(APFloat::getNaN(floatSemantics));
762 
763  Value absRealIsInf =
764  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
765  Value imagIsZero =
766  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
767  Value absRealIsNotInf = b.create<arith::XOrIOp>(
768  absRealIsInf, b.create<arith::ConstantIntOp>(true, /*width=*/1));
769 
770  Value imagNumIsNaN = b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO,
771  imagNum, imagNum, fmf);
772  Value resultRealIsNaN =
773  b.create<arith::AndIOp>(imagNumIsNaN, absRealIsNotInf);
774  Value resultImagIsZero = b.create<arith::OrIOp>(
775  imagIsZero, b.create<arith::AndIOp>(absRealIsInf, imagNumIsNaN));
776 
777  resultReal = b.create<arith::SelectOp>(resultRealIsNaN, nan, resultReal);
778  resultImag =
779  b.create<arith::SelectOp>(resultImagIsZero, zero, resultImag);
780  }
781 
782  if constexpr (std::is_same_v<Op, complex::TanOp>) {
783  // tan(x+yi) = -i*tanh(-y + xi)
784  std::swap(resultReal, resultImag);
785  resultImag = b.create<arith::MulFOp>(resultImag, negOne, fmf);
786  }
787 
788  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
789  resultImag);
790  return success();
791  }
792 };
793 
794 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
796 
797  LogicalResult
798  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
799  ConversionPatternRewriter &rewriter) const override {
800  auto loc = op.getLoc();
801  auto type = cast<ComplexType>(adaptor.getComplex().getType());
802  auto elementType = cast<FloatType>(type.getElementType());
803  Value real =
804  rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
805  Value imag =
806  rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
807  Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
808 
809  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
810 
811  return success();
812  }
813 };
814 
815 /// Converts lhs^y = (a+bi)^(c+di) to
816 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
817 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
818 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
819  ComplexType type, Value lhs, Value c, Value d,
820  arith::FastMathFlags fmf) {
821  auto elementType = cast<FloatType>(type.getElementType());
822 
823  Value a = builder.create<complex::ReOp>(lhs);
824  Value b = builder.create<complex::ImOp>(lhs);
825 
826  Value abs = builder.create<complex::AbsOp>(lhs, fmf);
827  Value absToC = builder.create<math::PowFOp>(abs, c, fmf);
828 
829  Value negD = builder.create<arith::NegFOp>(d, fmf);
830  Value argLhs = builder.create<math::Atan2Op>(b, a, fmf);
831  Value negDArgLhs = builder.create<arith::MulFOp>(negD, argLhs, fmf);
832  Value expNegDArgLhs = builder.create<math::ExpOp>(negDArgLhs, fmf);
833 
834  Value coeff = builder.create<arith::MulFOp>(absToC, expNegDArgLhs, fmf);
835  Value lnAbs = builder.create<math::LogOp>(abs, fmf);
836  Value cArgLhs = builder.create<arith::MulFOp>(c, argLhs, fmf);
837  Value dLnAbs = builder.create<arith::MulFOp>(d, lnAbs, fmf);
838  Value q = builder.create<arith::AddFOp>(cArgLhs, dLnAbs, fmf);
839  Value cosQ = builder.create<math::CosOp>(q, fmf);
840  Value sinQ = builder.create<math::SinOp>(q, fmf);
841 
842  Value inf = builder.create<arith::ConstantOp>(
843  elementType,
844  builder.getFloatAttr(elementType,
845  APFloat::getInf(elementType.getFloatSemantics())));
846  Value zero = builder.create<arith::ConstantOp>(
847  elementType, builder.getFloatAttr(elementType, 0.0));
848  Value one = builder.create<arith::ConstantOp>(
849  elementType, builder.getFloatAttr(elementType, 1.0));
850  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
851  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
852  Value complexInf = builder.create<complex::CreateOp>(type, inf, zero);
853 
854  // Case 0:
855  // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
856  // Branch Cuts for Complex Elementary Functions or Much Ado About
857  // Nothing's Sign Bit, W. Kahan, Section 10.
858  Value absEqZero =
859  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, abs, zero, fmf);
860  Value dEqZero =
861  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero, fmf);
862  Value cEqZero =
863  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero, fmf);
864  Value bEqZero =
865  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, b, zero, fmf);
866 
867  Value zeroLeC =
868  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, zero, c, fmf);
869  Value coeffCosQ = builder.create<arith::MulFOp>(coeff, cosQ, fmf);
870  Value coeffSinQ = builder.create<arith::MulFOp>(coeff, sinQ, fmf);
871  Value complexOneOrZero =
872  builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero);
873  Value coeffCosSin =
874  builder.create<complex::CreateOp>(type, coeffCosQ, coeffSinQ);
875  Value cutoff0 = builder.create<arith::SelectOp>(
876  builder.create<arith::AndIOp>(
877  builder.create<arith::AndIOp>(absEqZero, dEqZero), zeroLeC),
878  complexOneOrZero, coeffCosSin);
879 
880  // Case 1:
881  // x^0 is defined to be 1 for any x, see
882  // Branch Cuts for Complex Elementary Functions or Much Ado About
883  // Nothing's Sign Bit, W. Kahan, Section 10.
884  Value rhsEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
885  Value cutoff1 =
886  builder.create<arith::SelectOp>(rhsEqZero, complexOne, cutoff0);
887 
888  // Case 2:
889  // 1^(c + d*i) = 1 + 0*i
890  Value lhsEqOne = builder.create<arith::AndIOp>(
891  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, one, fmf),
892  bEqZero);
893  Value cutoff2 =
894  builder.create<arith::SelectOp>(lhsEqOne, complexOne, cutoff1);
895 
896  // Case 3:
897  // inf^(c + 0*i) = inf + 0*i, c > 0
898  Value lhsEqInf = builder.create<arith::AndIOp>(
899  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, a, inf, fmf),
900  bEqZero);
901  Value rhsGt0 = builder.create<arith::AndIOp>(
902  dEqZero,
903  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero, fmf));
904  Value cutoff3 = builder.create<arith::SelectOp>(
905  builder.create<arith::AndIOp>(lhsEqInf, rhsGt0), complexInf, cutoff2);
906 
907  // Case 4:
908  // inf^(c + 0*i) = 0 + 0*i, c < 0
909  Value rhsLt0 = builder.create<arith::AndIOp>(
910  dEqZero,
911  builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, c, zero, fmf));
912  Value cutoff4 = builder.create<arith::SelectOp>(
913  builder.create<arith::AndIOp>(lhsEqInf, rhsLt0), complexZero, cutoff3);
914 
915  return cutoff4;
916 }
917 
918 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
920 
921  LogicalResult
922  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
923  ConversionPatternRewriter &rewriter) const override {
924  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
925  auto type = cast<ComplexType>(adaptor.getLhs().getType());
926  auto elementType = cast<FloatType>(type.getElementType());
927 
928  Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
929  Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
930 
931  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
932  c, d, op.getFastmath())});
933  return success();
934  }
935 };
936 
937 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
939 
940  LogicalResult
941  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
942  ConversionPatternRewriter &rewriter) const override {
943  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
944  auto type = cast<ComplexType>(adaptor.getComplex().getType());
945  auto elementType = cast<FloatType>(type.getElementType());
946 
947  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
948 
949  auto cst = [&](APFloat v) {
950  return b.create<arith::ConstantOp>(elementType,
951  b.getFloatAttr(elementType, v));
952  };
953  const auto &floatSemantics = elementType.getFloatSemantics();
954  Value zero = cst(APFloat::getZero(floatSemantics));
955  Value inf = cst(APFloat::getInf(floatSemantics));
956  Value negHalf = b.create<arith::ConstantOp>(
957  elementType, b.getFloatAttr(elementType, -0.5));
958  Value nan = cst(APFloat::getNaN(floatSemantics));
959 
960  Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
961  Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
962  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
963  Value argArg = b.create<math::Atan2Op>(imag, real, fmf);
964  Value rsqrtArg = b.create<arith::MulFOp>(argArg, negHalf, fmf);
965  Value cos = b.create<math::CosOp>(rsqrtArg, fmf);
966  Value sin = b.create<math::SinOp>(rsqrtArg, fmf);
967 
968  Value resultReal = b.create<arith::MulFOp>(absRsqrt, cos, fmf);
969  Value resultImag = b.create<arith::MulFOp>(absRsqrt, sin, fmf);
970 
971  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
972  arith::FastMathFlags::ninf)) {
973  Value negOne = b.create<arith::ConstantOp>(
974  elementType, b.getFloatAttr(elementType, -1));
975 
976  Value realSignedZero = b.create<math::CopySignOp>(zero, real, fmf);
977  Value imagSignedZero = b.create<math::CopySignOp>(zero, imag, fmf);
978  Value negImagSignedZero =
979  b.create<arith::MulFOp>(negOne, imagSignedZero, fmf);
980 
981  Value absReal = b.create<math::AbsFOp>(real, fmf);
982  Value absImag = b.create<math::AbsFOp>(imag, fmf);
983 
984  Value absImagIsInf =
985  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absImag, inf, fmf);
986  Value realIsNan =
987  b.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, real, real, fmf);
988  Value realIsInf =
989  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, absReal, inf, fmf);
990  Value inIsNanInf = b.create<arith::AndIOp>(absImagIsInf, realIsNan);
991 
992  Value resultIsZero = b.create<arith::OrIOp>(inIsNanInf, realIsInf);
993 
994  resultReal =
995  b.create<arith::SelectOp>(resultIsZero, realSignedZero, resultReal);
996  resultImag = b.create<arith::SelectOp>(resultIsZero, negImagSignedZero,
997  resultImag);
998  }
999 
1000  Value isRealZero =
1001  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero, fmf);
1002  Value isImagZero =
1003  b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero, fmf);
1004  Value isZero = b.create<arith::AndIOp>(isRealZero, isImagZero);
1005 
1006  resultReal = b.create<arith::SelectOp>(isZero, inf, resultReal);
1007  resultImag = b.create<arith::SelectOp>(isZero, nan, resultImag);
1008 
1009  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1010  resultImag);
1011  return success();
1012  }
1013 };
1014 
1015 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1017 
1018  LogicalResult
1019  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1020  ConversionPatternRewriter &rewriter) const override {
1021  auto loc = op.getLoc();
1022  auto type = op.getType();
1023  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1024 
1025  Value real =
1026  rewriter.create<complex::ReOp>(loc, type, adaptor.getComplex());
1027  Value imag =
1028  rewriter.create<complex::ImOp>(loc, type, adaptor.getComplex());
1029 
1030  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1031 
1032  return success();
1033  }
1034 };
1035 
1036 } // namespace
1037 
1039  RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1040  // clang-format off
1041  patterns.add<
1042  AbsOpConversion,
1043  AngleOpConversion,
1044  Atan2OpConversion,
1045  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1046  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1047  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1048  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1049  ConjOpConversion,
1050  CosOpConversion,
1051  ExpOpConversion,
1052  Expm1OpConversion,
1053  Log1pOpConversion,
1054  LogOpConversion,
1055  MulOpConversion,
1056  NegOpConversion,
1057  SignOpConversion,
1058  SinOpConversion,
1059  SqrtOpConversion,
1060  TanTanhOpConversion<complex::TanOp>,
1061  TanTanhOpConversion<complex::TanhOp>,
1062  PowOpConversion,
1063  RsqrtOpConversion
1064  >(patterns.getContext());
1065 
1066  patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
1067 
1068  // clang-format on
1069 }
1070 
1071 namespace {
1072 struct ConvertComplexToStandardPass
1073  : public impl::ConvertComplexToStandardPassBase<
1074  ConvertComplexToStandardPass> {
1075  using ConvertComplexToStandardPassBase::ConvertComplexToStandardPassBase;
1076 
1077  void runOnOperation() override;
1078 };
1079 
1080 void ConvertComplexToStandardPass::runOnOperation() {
1081  // Convert to the Standard dialect using the converter defined above.
1084 
1085  ConversionTarget target(getContext());
1086  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1087  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1088  if (failed(
1089  applyPartialConversion(getOperation(), target, std::move(patterns))))
1090  signalPassFailure();
1091 }
1092 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the root operation).
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.