MLIR  22.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1 //===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
#include "mlir/IR/PatternMatch.h"
#include <type_traits>
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21 #include "mlir/Conversion/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
/// Selects which function of a complex number's magnitude computeAbs() emits.
enum class AbsFn {
  abs,   ///< The magnitude itself.
  sqrt,  ///< The square root of the magnitude.
  rsqrt, ///< The reciprocal square root of the magnitude.
};
29 
30 // Returns the absolute value, its square root or its reciprocal square root.
31 Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32  ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
33  Value one = arith::ConstantOp::create(b, real.getType(),
34  b.getFloatAttr(real.getType(), 1.0));
35 
36  Value absReal = math::AbsFOp::create(b, real, fmf);
37  Value absImag = math::AbsFOp::create(b, imag, fmf);
38 
39  Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40  Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
41 
42  // The lowering below requires NaNs and infinities to work correctly.
43  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45  Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
46  Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47  Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
48  Value result;
49 
50  if (fn == AbsFn::rsqrt) {
51  ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52  min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
53  max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
54  }
55 
56  if (fn == AbsFn::sqrt) {
57  Value quarter = arith::ConstantOp::create(
58  b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
59  // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
60  Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
61  Value p025 =
62  math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63  result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
64  } else {
65  Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66  result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
67  }
68 
69  Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70  result, fmfWithNaNInf);
71  return arith::SelectOp::create(b, isNaN, min, result);
72 }
73 
74 struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
76 
77  LogicalResult
78  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
80  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
81 
82  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
83 
84  Value real = complex::ReOp::create(b, adaptor.getComplex());
85  Value imag = complex::ImOp::create(b, adaptor.getComplex());
86  rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
87 
88  return success();
89  }
90 };
91 
92 // atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
93 struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
95 
96  LogicalResult
97  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
100 
101  auto type = cast<ComplexType>(op.getType());
102  Type elementType = type.getElementType();
103  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
104 
105  Value lhs = adaptor.getLhs();
106  Value rhs = adaptor.getRhs();
107 
108  Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109  Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110  Value rhsSquaredPlusLhsSquared =
111  complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112  Value sqrtOfRhsSquaredPlusLhsSquared =
113  complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
114 
115  Value zero =
116  arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
117  Value one = arith::ConstantOp::create(b, elementType,
118  b.getFloatAttr(elementType, 1));
119  Value i = complex::CreateOp::create(b, type, zero, one);
120  Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121  Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
122 
123  Value divResult = complex::DivOp::create(
124  b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125  Value logResult = complex::LogOp::create(b, divResult, fmf);
126 
127  Value negativeOne = arith::ConstantOp::create(
128  b, elementType, b.getFloatAttr(elementType, -1));
129  Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
130 
131  rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
132  return success();
133  }
134 };
135 
136 template <typename ComparisonOp, arith::CmpFPredicate p>
137 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
139  using ResultCombiner =
140  std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141  arith::AndIOp, arith::OrIOp>;
142 
143  LogicalResult
144  matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
145  ConversionPatternRewriter &rewriter) const override {
146  auto loc = op.getLoc();
147  auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
148 
149  Value realLhs =
150  complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
151  Value imagLhs =
152  complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
153  Value realRhs =
154  complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
155  Value imagRhs =
156  complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157  Value realComparison =
158  arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159  Value imagComparison =
160  arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
161 
162  rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
163  imagComparison);
164  return success();
165  }
166 };
167 
168 // Default conversion which applies the BinaryStandardOp separately on the real
169 // and imaginary parts. Can for example be used for complex::AddOp and
170 // complex::SubOp.
171 template <typename BinaryComplexOp, typename BinaryStandardOp>
172 struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
174 
175  LogicalResult
176  matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
177  ConversionPatternRewriter &rewriter) const override {
178  auto type = cast<ComplexType>(adaptor.getLhs().getType());
179  auto elementType = cast<FloatType>(type.getElementType());
180  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
181  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
182 
183  Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184  Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185  Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186  realRhs, fmf.getValue());
187  Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188  Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189  Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190  imagRhs, fmf.getValue());
191  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
192  resultImag);
193  return success();
194  }
195 };
196 
197 template <typename TrigonometricOp>
198 struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
199  using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
200 
202 
203  LogicalResult
204  matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
205  ConversionPatternRewriter &rewriter) const override {
206  auto loc = op.getLoc();
207  auto type = cast<ComplexType>(adaptor.getComplex().getType());
208  auto elementType = cast<FloatType>(type.getElementType());
209  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
210 
211  Value real =
212  complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
213  Value imag =
214  complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
215 
216  // Trigonometric ops use a set of common building blocks to convert to real
217  // ops. Here we create these building blocks and call into an op-specific
218  // implementation in the subclass to combine them.
219  Value half = arith::ConstantOp::create(
220  rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
221  Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222  Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223  Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224  Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225  Value cos = math::CosOp::create(rewriter, loc, real, fmf);
226 
227  auto resultPair =
228  combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
229 
230  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
231  resultPair.second);
232  return success();
233  }
234 
235  virtual std::pair<Value, Value>
236  combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
237  Value cos, ConversionPatternRewriter &rewriter,
238  arith::FastMathFlagsAttr fmf) const = 0;
239 };
240 
241 struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
242  using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
243 
244  std::pair<Value, Value> combine(Location loc, Value scaledExp,
245  Value reciprocalExp, Value sin, Value cos,
246  ConversionPatternRewriter &rewriter,
247  arith::FastMathFlagsAttr fmf) const override {
248  // Complex cosine is defined as;
249  // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
250  // Plugging in:
251  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
252  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
253  // and defining t := exp(y)
254  // We get:
255  // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
256  // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
257  Value sum =
258  arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259  Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
260  Value diff =
261  arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262  Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263  return {resultReal, resultImag};
264  }
265 };
266 
267 struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
268  DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
269  : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
270 
272 
273  LogicalResult
274  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
275  ConversionPatternRewriter &rewriter) const override {
276  auto loc = op.getLoc();
277  auto type = cast<ComplexType>(adaptor.getLhs().getType());
278  auto elementType = cast<FloatType>(type.getElementType());
279  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
280 
281  Value lhsReal =
282  complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
283  Value lhsImag =
284  complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
285  Value rhsReal =
286  complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
287  Value rhsImag =
288  complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
289 
290  Value resultReal, resultImag;
291 
292  if (complexRange == complex::ComplexRangeFlags::basic ||
293  complexRange == complex::ComplexRangeFlags::none) {
295  rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
296  &resultImag);
297  } else if (complexRange == complex::ComplexRangeFlags::improved) {
299  rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
300  &resultImag);
301  }
302 
303  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
304  resultImag);
305 
306  return success();
307  }
308 
309 private:
310  complex::ComplexRangeFlags complexRange;
311 };
312 
313 struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
315 
316  // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317  // Handle special cases as StableHLO implementation does:
318  // 1. When b == 0, set imag(exp(z)) = 0
319  // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
320  LogicalResult
321  matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
322  ConversionPatternRewriter &rewriter) const override {
323  auto loc = op.getLoc();
324  auto type = cast<ComplexType>(adaptor.getComplex().getType());
325  auto ET = cast<FloatType>(type.getElementType());
326  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327  const auto &floatSemantics = ET.getFloatSemantics();
328  ImplicitLocOpBuilder b(loc, rewriter);
329 
330  Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331  Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
332  Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
333  Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
334  Value inf = arith::ConstantOp::create(
335  b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
336 
337  Value exp = math::ExpOp::create(b, x, fmf);
338  Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339  Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340  Value cos = math::CosOp::create(b, y, fmf);
341  Value sin = math::SinOp::create(b, y, fmf);
342 
343  Value expIsInf =
344  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345  Value yIsZero =
346  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
347 
348  // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349  Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350  Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351  Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
352  Value resultReal =
353  arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
354 
355  // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356  // exp(x/2)*sin(y)*exp(x/2)
357  Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358  Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359  Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
360  Value imagNonZero =
361  arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362  Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
363 
364  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
365  resultImag);
366  return success();
367  }
368 };
369 
370 Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
371  ArrayRef<double> coefficients,
372  arith::FastMathFlagsAttr fmf) {
373  auto argType = mlir::cast<FloatType>(arg.getType());
374  Value poly =
375  arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
376  for (unsigned i = 1; i < coefficients.size(); ++i) {
377  poly = math::FmaOp::create(
378  b, poly, arg,
379  arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
380  fmf);
381  }
382  return poly;
383 }
384 
385 struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
387 
388  // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
389  // [handle inaccuracies when a and/or b are small]
390  // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
391  // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
392  LogicalResult
393  matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
394  ConversionPatternRewriter &rewriter) const override {
395  auto type = op.getType();
396  auto elemType = mlir::cast<FloatType>(type.getElementType());
397 
398  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
399  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
400  Value real = complex::ReOp::create(b, adaptor.getComplex());
401  Value imag = complex::ImOp::create(b, adaptor.getComplex());
402 
403  Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
404  Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
405 
406  Value expm1Real = math::ExpM1Op::create(b, real, fmf);
407  Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
408 
409  Value sinImag = math::SinOp::create(b, imag, fmf);
410  Value cosm1Imag = emitCosm1(imag, fmf, b);
411  Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
412 
413  Value realResult = arith::AddFOp::create(
414  b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
415 
416  Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
417  zero, fmf.getValue());
418  Value imagResult = arith::SelectOp::create(
419  b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
420 
421  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
422  imagResult);
423  return success();
424  }
425 
426 private:
427  Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
428  ImplicitLocOpBuilder &b) const {
429  auto argType = mlir::cast<FloatType>(arg.getType());
430  auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
431  auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
432 
433  // Algorithm copied from cephes cosm1.
434  SmallVector<double, 7> kCoeffs{
435  4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436  2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437  2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438  4.1666666666666666609054E-2,
439  };
440  Value cos = math::CosOp::create(b, arg, fmf);
441  Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
442 
443  Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
444  Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
445  Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
446 
447  auto forSmallArg =
448  arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
449  arith::MulFOp::create(b, negHalf, argPow2, fmf));
450 
451  // (pi/4)^2 is approximately 0.61685
452  Value piOver4Pow2 =
453  arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
454  Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
455  piOver4Pow2, fmf.getValue());
456  return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
457  }
458 };
459 
460 struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
462 
463  LogicalResult
464  matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
465  ConversionPatternRewriter &rewriter) const override {
466  auto type = cast<ComplexType>(adaptor.getComplex().getType());
467  auto elementType = cast<FloatType>(type.getElementType());
468  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
469  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
470 
471  Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
472  fmf.getValue());
473  Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
474  Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
475  Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
476  Value resultImag =
477  math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
478  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
479  resultImag);
480  return success();
481  }
482 };
483 
484 struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
486 
487  LogicalResult
488  matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
489  ConversionPatternRewriter &rewriter) const override {
490  auto type = cast<ComplexType>(adaptor.getComplex().getType());
491  auto elementType = cast<FloatType>(type.getElementType());
492  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
493  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
494 
495  Value real = complex::ReOp::create(b, adaptor.getComplex());
496  Value imag = complex::ImOp::create(b, adaptor.getComplex());
497 
498  Value half = arith::ConstantOp::create(b, elementType,
499  b.getFloatAttr(elementType, 0.5));
500  Value one = arith::ConstantOp::create(b, elementType,
501  b.getFloatAttr(elementType, 1));
502  Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
503  Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
504  Value absImag = math::AbsFOp::create(b, imag, fmf);
505 
506  Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
507  Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
508 
509  Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
510  realPlusOne, absImag, fmf);
511  Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
512  Value maxAbsOfRealPlusOneAndImagMinusOne =
513  arith::SelectOp::create(b, useReal, real, maxMinusOne);
514  arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515  fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516  Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
517  Value logOfMaxAbsOfRealPlusOneAndImag =
518  math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519  Value logOfSqrtPart = math::Log1pOp::create(
520  b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
521  fmfWithNaNInf);
522  Value r = arith::AddFOp::create(
523  b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
524  logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525  Value resultReal = arith::SelectOp::create(
526  b,
527  arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
528  fmfWithNaNInf),
529  minAbs, r);
530  Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
531  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
532  resultImag);
533  return success();
534  }
535 };
536 
537 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
539 
540  LogicalResult
541  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
542  ConversionPatternRewriter &rewriter) const override {
543  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
544  auto type = cast<ComplexType>(adaptor.getLhs().getType());
545  auto elementType = cast<FloatType>(type.getElementType());
546  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547  auto fmfValue = fmf.getValue();
548  Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
549  Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
550  Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
551  Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
552  Value lhsRealTimesRhsReal =
553  arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
554  Value lhsImagTimesRhsImag =
555  arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
556  Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
557  lhsImagTimesRhsImag, fmfValue);
558  Value lhsImagTimesRhsReal =
559  arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
560  Value lhsRealTimesRhsImag =
561  arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
562  Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
563  lhsRealTimesRhsImag, fmfValue);
564  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
565  return success();
566  }
567 };
568 
569 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
571 
572  LogicalResult
573  matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
574  ConversionPatternRewriter &rewriter) const override {
575  auto loc = op.getLoc();
576  auto type = cast<ComplexType>(adaptor.getComplex().getType());
577  auto elementType = cast<FloatType>(type.getElementType());
578 
579  Value real =
580  complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
581  Value imag =
582  complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
583  Value negReal = arith::NegFOp::create(rewriter, loc, real);
584  Value negImag = arith::NegFOp::create(rewriter, loc, imag);
585  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
586  return success();
587  }
588 };
589 
590 struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
591  using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
592 
593  std::pair<Value, Value> combine(Location loc, Value scaledExp,
594  Value reciprocalExp, Value sin, Value cos,
595  ConversionPatternRewriter &rewriter,
596  arith::FastMathFlagsAttr fmf) const override {
597  // Complex sine is defined as;
598  // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
599  // Plugging in:
600  // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
601  // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
602  // and defining t := exp(y)
603  // We get:
604  // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
605  // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
606  Value sum =
607  arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
608  Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
609  Value diff =
610  arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
611  Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
612  return {resultReal, resultImag};
613  }
614 };
615 
616 // The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
617 struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
619 
620  LogicalResult
621  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
622  ConversionPatternRewriter &rewriter) const override {
623  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
624 
625  auto type = cast<ComplexType>(op.getType());
626  auto elementType = cast<FloatType>(type.getElementType());
627  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
628 
629  auto cst = [&](APFloat v) {
630  return arith::ConstantOp::create(b, elementType,
631  b.getFloatAttr(elementType, v));
632  };
633  const auto &floatSemantics = elementType.getFloatSemantics();
634  Value zero = cst(APFloat::getZero(floatSemantics));
635  Value half = arith::ConstantOp::create(b, elementType,
636  b.getFloatAttr(elementType, 0.5));
637 
638  Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
639  Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
640  Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
641  Value argArg = math::Atan2Op::create(b, imag, real, fmf);
642  Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
643  Value cos = math::CosOp::create(b, sqrtArg, fmf);
644  Value sin = math::SinOp::create(b, sqrtArg, fmf);
645  // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
646  // 0 * inf.
647  Value sinIsZero =
648  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
649 
650  Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
651  Value resultImag = arith::SelectOp::create(
652  b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
653  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
654  arith::FastMathFlags::ninf)) {
655  Value inf = cst(APFloat::getInf(floatSemantics));
656  Value negInf = cst(APFloat::getInf(floatSemantics, true));
657  Value nan = cst(APFloat::getNaN(floatSemantics));
658  Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
659 
660  Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
661  absImag, inf, fmf);
662  Value absImagIsNotInf = arith::CmpFOp::create(
663  b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
664  Value realIsInf =
665  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
666  Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
667  real, negInf, fmf);
668 
669  resultReal = arith::SelectOp::create(
670  b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
671  resultReal);
672  resultReal = arith::SelectOp::create(
673  b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
674 
675  Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
676  resultImag = arith::SelectOp::create(
677  b,
678  arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
679  nan, resultImag);
680  resultImag = arith::SelectOp::create(
681  b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
682  resultImag);
683  }
684 
685  Value resultIsZero =
686  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
687  resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
688  resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
689 
690  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
691  resultImag);
692  return success();
693  }
694 };
695 
696 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
698 
699  LogicalResult
700  matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
701  ConversionPatternRewriter &rewriter) const override {
702  auto type = cast<ComplexType>(adaptor.getComplex().getType());
703  auto elementType = cast<FloatType>(type.getElementType());
704  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
705  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
706 
707  Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
708  Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
709  Value zero =
710  arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
711  Value realIsZero =
712  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
713  Value imagIsZero =
714  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
715  Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
716  auto abs =
717  complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
718  Value realSign = arith::DivFOp::create(b, real, abs, fmf);
719  Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
720  Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
721  rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
722  adaptor.getComplex(), sign);
723  return success();
724  }
725 };
726 
727 template <typename Op>
728 struct TanTanhOpConversion : public OpConversionPattern<Op> {
730 
731  LogicalResult
732  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
733  ConversionPatternRewriter &rewriter) const override {
734  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
735  auto loc = op.getLoc();
736  auto type = cast<ComplexType>(adaptor.getComplex().getType());
737  auto elementType = cast<FloatType>(type.getElementType());
738  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
739  const auto &floatSemantics = elementType.getFloatSemantics();
740 
741  Value real =
742  complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
743  Value imag =
744  complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
745  Value negOne = arith::ConstantOp::create(b, elementType,
746  b.getFloatAttr(elementType, -1.0));
747 
748  if constexpr (std::is_same_v<Op, complex::TanOp>) {
749  // tan(x+yi) = -i*tanh(-y + xi)
750  std::swap(real, imag);
751  real = arith::MulFOp::create(b, real, negOne, fmf);
752  }
753 
754  auto cst = [&](APFloat v) {
755  return arith::ConstantOp::create(b, elementType,
756  b.getFloatAttr(elementType, v));
757  };
758  Value inf = cst(APFloat::getInf(floatSemantics));
759  Value four = arith::ConstantOp::create(b, elementType,
760  b.getFloatAttr(elementType, 4.0));
761  Value twoReal = arith::AddFOp::create(b, real, real, fmf);
762  Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
763 
764  Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
765  Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
766  Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
767  expNegTwoRealMinusOne, fmf);
768 
769  Value cosImag = math::CosOp::create(b, imag, fmf);
770  Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
771  Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
772  Value sinImag = math::SinOp::create(b, imag, fmf);
773 
774  Value imagNum = arith::MulFOp::create(
775  b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
776 
777  Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
778  expNegTwoRealMinusOne, fmf);
779  Value denom =
780  arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
781 
782  Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
783  expSumMinusTwo, inf, fmf);
784  Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
785 
786  Value resultReal = arith::SelectOp::create(
787  b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
788  Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
789 
790  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
791  arith::FastMathFlags::ninf)) {
792  Value absReal = math::AbsFOp::create(b, real, fmf);
793  Value zero = arith::ConstantOp::create(b, elementType,
794  b.getFloatAttr(elementType, 0.0));
795  Value nan = cst(APFloat::getNaN(floatSemantics));
796 
797  Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
798  absReal, inf, fmf);
799  Value imagIsZero =
800  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
801  Value absRealIsNotInf = arith::XOrIOp::create(
802  b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
803 
804  Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
805  imagNum, imagNum, fmf);
806  Value resultRealIsNaN =
807  arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
808  Value resultImagIsZero = arith::OrIOp::create(
809  b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
810 
811  resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
812  resultImag =
813  arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
814  }
815 
816  if constexpr (std::is_same_v<Op, complex::TanOp>) {
817  // tan(x+yi) = -i*tanh(-y + xi)
818  std::swap(resultReal, resultImag);
819  resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
820  }
821 
822  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
823  resultImag);
824  return success();
825  }
826 };
827 
828 struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
830 
831  LogicalResult
832  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
833  ConversionPatternRewriter &rewriter) const override {
834  auto loc = op.getLoc();
835  auto type = cast<ComplexType>(adaptor.getComplex().getType());
836  auto elementType = cast<FloatType>(type.getElementType());
837  Value real =
838  complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
839  Value imag =
840  complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
841  Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
842 
843  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
844 
845  return success();
846  }
847 };
848 
849 /// Converts lhs^y = (a+bi)^(c+di) to
850 /// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
851 /// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
852 static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
853  ComplexType type, Value lhs, Value c, Value d,
854  arith::FastMathFlags fmf) {
855  auto elementType = cast<FloatType>(type.getElementType());
856 
857  Value a = complex::ReOp::create(builder, lhs);
858  Value b = complex::ImOp::create(builder, lhs);
859 
860  Value abs = complex::AbsOp::create(builder, lhs, fmf);
861  Value absToC = math::PowFOp::create(builder, abs, c, fmf);
862 
863  Value negD = arith::NegFOp::create(builder, d, fmf);
864  Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
865  Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
866  Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
867 
868  Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
869  Value lnAbs = math::LogOp::create(builder, abs, fmf);
870  Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
871  Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
872  Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
873  Value cosQ = math::CosOp::create(builder, q, fmf);
874  Value sinQ = math::SinOp::create(builder, q, fmf);
875 
876  Value inf = arith::ConstantOp::create(
877  builder, elementType,
878  builder.getFloatAttr(elementType,
879  APFloat::getInf(elementType.getFloatSemantics())));
880  Value zero = arith::ConstantOp::create(
881  builder, elementType, builder.getFloatAttr(elementType, 0.0));
882  Value one = arith::ConstantOp::create(builder, elementType,
883  builder.getFloatAttr(elementType, 1.0));
884  Value complexOne = complex::CreateOp::create(builder, type, one, zero);
885  Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
886  Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
887 
888  // Case 0:
889  // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
890  // Branch Cuts for Complex Elementary Functions or Much Ado About
891  // Nothing's Sign Bit, W. Kahan, Section 10.
892  Value absEqZero =
893  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
894  Value dEqZero =
895  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
896  Value cEqZero =
897  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
898  Value bEqZero =
899  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
900 
901  Value zeroLeC =
902  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
903  Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
904  Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
905  Value complexOneOrZero =
906  arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
907  Value coeffCosSin =
908  complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
909  Value cutoff0 = arith::SelectOp::create(
910  builder,
911  arith::AndIOp::create(
912  builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
913  complexOneOrZero, coeffCosSin);
914 
915  // Case 1:
916  // x^0 is defined to be 1 for any x, see
917  // Branch Cuts for Complex Elementary Functions or Much Ado About
918  // Nothing's Sign Bit, W. Kahan, Section 10.
919  Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
920  Value cutoff1 =
921  arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
922 
923  // Case 2:
924  // 1^(c + d*i) = 1 + 0*i
925  Value lhsEqOne = arith::AndIOp::create(
926  builder,
927  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
928  bEqZero);
929  Value cutoff2 =
930  arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
931 
932  // Case 3:
933  // inf^(c + 0*i) = inf + 0*i, c > 0
934  Value lhsEqInf = arith::AndIOp::create(
935  builder,
936  arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
937  bEqZero);
938  Value rhsGt0 = arith::AndIOp::create(
939  builder, dEqZero,
940  arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
941  Value cutoff3 = arith::SelectOp::create(
942  builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
943  cutoff2);
944 
945  // Case 4:
946  // inf^(c + 0*i) = 0 + 0*i, c < 0
947  Value rhsLt0 = arith::AndIOp::create(
948  builder, dEqZero,
949  arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
950  Value cutoff4 = arith::SelectOp::create(
951  builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
952  cutoff3);
953 
954  return cutoff4;
955 }
956 
957 struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
959 
960  LogicalResult
961  matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
962  ConversionPatternRewriter &rewriter) const override {
963  ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
964  auto type = cast<ComplexType>(op.getType());
965  auto elementType = cast<FloatType>(type.getElementType());
966 
967  Value floatExponent =
968  arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
969  Value zero = arith::ConstantOp::create(
970  builder, elementType, builder.getFloatAttr(elementType, 0.0));
971  Value complexExponent =
972  complex::CreateOp::create(builder, type, floatExponent, zero);
973 
974  auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
975  complexExponent, op.getFastmathAttr());
976  rewriter.replaceOp(op, pow.getResult());
977  return success();
978  }
979 };
980 
981 struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
983 
984  LogicalResult
985  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
986  ConversionPatternRewriter &rewriter) const override {
987  mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
988  auto type = cast<ComplexType>(adaptor.getLhs().getType());
989  auto elementType = cast<FloatType>(type.getElementType());
990 
991  Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
992  Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
993 
994  rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
995  c, d, op.getFastmath())});
996  return success();
997  }
998 };
999 
1000 struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1002 
1003  LogicalResult
1004  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1005  ConversionPatternRewriter &rewriter) const override {
1006  mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1007  auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008  auto elementType = cast<FloatType>(type.getElementType());
1009 
1010  arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1011 
1012  auto cst = [&](APFloat v) {
1013  return arith::ConstantOp::create(b, elementType,
1014  b.getFloatAttr(elementType, v));
1015  };
1016  const auto &floatSemantics = elementType.getFloatSemantics();
1017  Value zero = cst(APFloat::getZero(floatSemantics));
1018  Value inf = cst(APFloat::getInf(floatSemantics));
1019  Value negHalf = arith::ConstantOp::create(
1020  b, elementType, b.getFloatAttr(elementType, -0.5));
1021  Value nan = cst(APFloat::getNaN(floatSemantics));
1022 
1023  Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
1024  Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
1025  Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1026  Value argArg = math::Atan2Op::create(b, imag, real, fmf);
1027  Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
1028  Value cos = math::CosOp::create(b, rsqrtArg, fmf);
1029  Value sin = math::SinOp::create(b, rsqrtArg, fmf);
1030 
1031  Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
1032  Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
1033 
1034  if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1035  arith::FastMathFlags::ninf)) {
1036  Value negOne = arith::ConstantOp::create(b, elementType,
1037  b.getFloatAttr(elementType, -1));
1038 
1039  Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
1040  Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
1041  Value negImagSignedZero =
1042  arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
1043 
1044  Value absReal = math::AbsFOp::create(b, real, fmf);
1045  Value absImag = math::AbsFOp::create(b, imag, fmf);
1046 
1047  Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1048  absImag, inf, fmf);
1049  Value realIsNan =
1050  arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
1051  Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1052  absReal, inf, fmf);
1053  Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
1054 
1055  Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
1056 
1057  resultReal =
1058  arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
1059  resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
1060  resultImag);
1061  }
1062 
1063  Value isRealZero =
1064  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
1065  Value isImagZero =
1066  arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
1067  Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
1068 
1069  resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
1070  resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
1071 
1072  rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1073  resultImag);
1074  return success();
1075  }
1076 };
1077 
1078 struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1080 
1081  LogicalResult
1082  matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1083  ConversionPatternRewriter &rewriter) const override {
1084  auto loc = op.getLoc();
1085  auto type = op.getType();
1086  arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1087 
1088  Value real =
1089  complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1090  Value imag =
1091  complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
1092 
1093  rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1094 
1095  return success();
1096  }
1097 };
1098 
1099 } // namespace
1100 
1102  RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1103  // clang-format off
1104  patterns.add<
1105  AbsOpConversion,
1106  AngleOpConversion,
1107  Atan2OpConversion,
1108  BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1109  BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1110  ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1111  ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1112  ConjOpConversion,
1113  CosOpConversion,
1114  ExpOpConversion,
1115  Expm1OpConversion,
1116  Log1pOpConversion,
1117  LogOpConversion,
1118  MulOpConversion,
1119  NegOpConversion,
1120  SignOpConversion,
1121  SinOpConversion,
1122  SqrtOpConversion,
1123  TanTanhOpConversion<complex::TanOp>,
1124  TanTanhOpConversion<complex::TanhOp>,
1125  PowiOpConversion,
1126  PowOpConversion,
1127  RsqrtOpConversion
1128  >(patterns.getContext());
1129 
1130  patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
1131 
1132  // clang-format on
1133 }
1134 
1135 namespace {
1136 struct ConvertComplexToStandardPass
1137  : public impl::ConvertComplexToStandardPassBase<
1138  ConvertComplexToStandardPass> {
1139  using Base::Base;
1140 
1141  void runOnOperation() override;
1142 };
1143 
1144 void ConvertComplexToStandardPass::runOnOperation() {
1145  // Convert to the Standard dialect using the converter defined above.
1148 
1149  ConversionTarget target(getContext());
1150  target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1151  target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1152  if (failed(
1153  applyPartialConversion(getOperation(), target, std::move(patterns))))
1154  signalPassFailure();
1155 }
1156 } // namespace
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:129
This provides public APIs that all operations should have.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition: ArithOps.cpp:258
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.