MLIR 22.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include <type_traits>
18
19namespace mlir {
20#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27
// Selects which quantity computeAbs emits for a complex value z.
enum class AbsFn {
  abs,  // |z|
  sqrt, // sqrt(|z|)
  rsqrt // 1 / sqrt(|z|)
};
29
30// Returns the absolute value, its square root or its reciprocal square root.
31Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
33 Value one = arith::ConstantOp::create(b, real.getType(),
34 b.getFloatAttr(real.getType(), 1.0));
35
36 Value absReal = math::AbsFOp::create(b, real, fmf);
37 Value absImag = math::AbsFOp::create(b, imag, fmf);
38
39 Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
41
42 // The lowering below requires NaNs and infinities to work correctly.
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
49
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
54 }
55
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
58 b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
59 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
60 Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
61 Value p025 =
62 math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
64 } else {
65 Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
67 }
68
69 Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70 result, fmfWithNaNInf);
71 return arith::SelectOp::create(b, isNaN, min, result);
72}
73
74struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
75 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
76
77 LogicalResult
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
80 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
81
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
83
84 Value real = complex::ReOp::create(b, adaptor.getComplex());
85 Value imag = complex::ImOp::create(b, adaptor.getComplex());
86 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
87
88 return success();
89 }
90};
91
92// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
93struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
94 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
95
96 LogicalResult
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
100
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
104
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
107
108 Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
114
115 Value zero =
116 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(b, elementType,
118 b.getFloatAttr(elementType, 1));
119 Value i = complex::CreateOp::create(b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
122
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(b, divResult, fmf);
126
127 Value negativeOne = arith::ConstantOp::create(
128 b, elementType, b.getFloatAttr(elementType, -1));
129 Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
130
131 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
132 return success();
133 }
134};
135
136template <typename ComparisonOp, arith::CmpFPredicate p>
137struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
138 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
142
143 LogicalResult
144 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
148
149 Value realLhs =
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
151 Value imagLhs =
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
153 Value realRhs =
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
155 Value imagRhs =
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
161
162 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
163 imagComparison);
164 return success();
165 }
166};
167
168// Default conversion which applies the BinaryStandardOp separately on the real
169// and imaginary parts. Can for example be used for complex::AddOp and
170// complex::SubOp.
171template <typename BinaryComplexOp, typename BinaryStandardOp>
172struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
173 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
174
175 LogicalResult
176 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter) const override {
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
180 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
182
183 Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
191 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
192 resultImag);
193 return success();
194 }
195};
196
197template <typename TrigonometricOp>
198struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
199 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
200
201 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
202
203 LogicalResult
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
210
211 Value real =
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
213 Value imag =
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
215
216 // Trigonometric ops use a set of common building blocks to convert to real
217 // ops. Here we create these building blocks and call into an op-specific
218 // implementation in the subclass to combine them.
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc, real, fmf);
226
227 auto resultPair =
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
229
230 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
231 resultPair.second);
232 return success();
233 }
234
235 virtual std::pair<Value, Value>
236 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
237 Value cos, ConversionPatternRewriter &rewriter,
238 arith::FastMathFlagsAttr fmf) const = 0;
239};
240
241struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
243
244 std::pair<Value, Value> combine(Location loc, Value scaledExp,
245 Value reciprocalExp, Value sin, Value cos,
246 ConversionPatternRewriter &rewriter,
247 arith::FastMathFlagsAttr fmf) const override {
248 // Complex cosine is defined as;
249 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
250 // Plugging in:
251 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
252 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
253 // and defining t := exp(y)
254 // We get:
255 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
256 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
257 Value sum =
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
260 Value diff =
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
264 }
265};
266
267struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
268 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
269 : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
270
271 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
272
273 LogicalResult
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
280
281 Value lhsReal =
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
283 Value lhsImag =
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
285 Value rhsReal =
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
287 Value rhsImag =
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
289
290 Value resultReal, resultImag;
291
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
296 &resultImag);
297 } else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
300 &resultImag);
301 }
302
303 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
304 resultImag);
305
306 return success();
307 }
308
309private:
310 complex::ComplexRangeFlags complexRange;
311};
312
313struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
314 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
315
316 // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317 // Handle special cases as StableHLO implementation does:
318 // 1. When b == 0, set imag(exp(z)) = 0
319 // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
320 LogicalResult
321 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323 auto loc = op.getLoc();
324 auto type = cast<ComplexType>(adaptor.getComplex().getType());
325 auto ET = cast<FloatType>(type.getElementType());
326 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327 const auto &floatSemantics = ET.getFloatSemantics();
328 ImplicitLocOpBuilder b(loc, rewriter);
329
330 Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331 Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
332 Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
333 Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
334 Value inf = arith::ConstantOp::create(
335 b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
336
337 Value exp = math::ExpOp::create(b, x, fmf);
338 Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339 Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340 Value cos = math::CosOp::create(b, y, fmf);
341 Value sin = math::SinOp::create(b, y, fmf);
342
343 Value expIsInf =
344 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345 Value yIsZero =
346 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
347
348 // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349 Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350 Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351 Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
352 Value resultReal =
353 arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
354
355 // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356 // exp(x/2)*sin(y)*exp(x/2)
357 Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358 Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359 Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
360 Value imagNonZero =
361 arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362 Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
363
364 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
365 resultImag);
366 return success();
367 }
368};
369
370Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
371 ArrayRef<double> coefficients,
372 arith::FastMathFlagsAttr fmf) {
373 auto argType = mlir::cast<FloatType>(arg.getType());
374 Value poly =
375 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
376 for (unsigned i = 1; i < coefficients.size(); ++i) {
377 poly = math::FmaOp::create(
378 b, poly, arg,
379 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
380 fmf);
381 }
382 return poly;
383}
384
385struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
386 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
387
388 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
389 // [handle inaccuracies when a and/or b are small]
390 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
391 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
392 LogicalResult
393 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
394 ConversionPatternRewriter &rewriter) const override {
395 auto type = op.getType();
396 auto elemType = mlir::cast<FloatType>(type.getElementType());
397
398 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
399 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
400 Value real = complex::ReOp::create(b, adaptor.getComplex());
401 Value imag = complex::ImOp::create(b, adaptor.getComplex());
402
403 Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
404 Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
405
406 Value expm1Real = math::ExpM1Op::create(b, real, fmf);
407 Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
408
409 Value sinImag = math::SinOp::create(b, imag, fmf);
410 Value cosm1Imag = emitCosm1(imag, fmf, b);
411 Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
412
413 Value realResult = arith::AddFOp::create(
414 b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
415
416 Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
417 zero, fmf.getValue());
418 Value imagResult = arith::SelectOp::create(
419 b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
420
421 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
422 imagResult);
423 return success();
424 }
425
426private:
427 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
428 ImplicitLocOpBuilder &b) const {
429 auto argType = mlir::cast<FloatType>(arg.getType());
430 auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
431 auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
432
433 // Algorithm copied from cephes cosm1.
434 SmallVector<double, 7> kCoeffs{
435 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438 4.1666666666666666609054E-2,
439 };
440 Value cos = math::CosOp::create(b, arg, fmf);
441 Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
442
443 Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
444 Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
445 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
446
447 auto forSmallArg =
448 arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
449 arith::MulFOp::create(b, negHalf, argPow2, fmf));
450
451 // (pi/4)^2 is approximately 0.61685
452 Value piOver4Pow2 =
453 arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
454 Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
455 piOver4Pow2, fmf.getValue());
456 return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
457 }
458};
459
460struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
461 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
462
463 LogicalResult
464 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const override {
466 auto type = cast<ComplexType>(adaptor.getComplex().getType());
467 auto elementType = cast<FloatType>(type.getElementType());
468 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
469 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
470
471 Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
472 fmf.getValue());
473 Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
474 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
475 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
476 Value resultImag =
477 math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
478 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
479 resultImag);
480 return success();
481 }
482};
483
484struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
485 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
486
487 LogicalResult
488 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 auto type = cast<ComplexType>(adaptor.getComplex().getType());
491 auto elementType = cast<FloatType>(type.getElementType());
492 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
493 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
494
495 Value real = complex::ReOp::create(b, adaptor.getComplex());
496 Value imag = complex::ImOp::create(b, adaptor.getComplex());
497
498 Value half = arith::ConstantOp::create(b, elementType,
499 b.getFloatAttr(elementType, 0.5));
500 Value one = arith::ConstantOp::create(b, elementType,
501 b.getFloatAttr(elementType, 1));
502 Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
503 Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
504 Value absImag = math::AbsFOp::create(b, imag, fmf);
505
506 Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
507 Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
508
509 Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
510 realPlusOne, absImag, fmf);
511 Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
512 Value maxAbsOfRealPlusOneAndImagMinusOne =
513 arith::SelectOp::create(b, useReal, real, maxMinusOne);
514 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516 Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
517 Value logOfMaxAbsOfRealPlusOneAndImag =
518 math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519 Value logOfSqrtPart = math::Log1pOp::create(
520 b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
521 fmfWithNaNInf);
522 Value r = arith::AddFOp::create(
523 b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
524 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525 Value resultReal = arith::SelectOp::create(
526 b,
527 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
528 fmfWithNaNInf),
529 minAbs, r);
530 Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
531 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
532 resultImag);
533 return success();
534 }
535};
536
537struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
538 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
539
540 LogicalResult
541 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter) const override {
543 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
544 auto type = cast<ComplexType>(adaptor.getLhs().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547 auto fmfValue = fmf.getValue();
548 Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
549 Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
550 Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
551 Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
552 Value lhsRealTimesRhsReal =
553 arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
554 Value lhsImagTimesRhsImag =
555 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
556 Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
557 lhsImagTimesRhsImag, fmfValue);
558 Value lhsImagTimesRhsReal =
559 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
560 Value lhsRealTimesRhsImag =
561 arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
562 Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
563 lhsRealTimesRhsImag, fmfValue);
564 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
565 return success();
566 }
567};
568
569struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
570 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
571
572 LogicalResult
573 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter) const override {
575 auto loc = op.getLoc();
576 auto type = cast<ComplexType>(adaptor.getComplex().getType());
577 auto elementType = cast<FloatType>(type.getElementType());
578
579 Value real =
580 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
581 Value imag =
582 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
583 Value negReal = arith::NegFOp::create(rewriter, loc, real);
584 Value negImag = arith::NegFOp::create(rewriter, loc, imag);
585 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
586 return success();
587 }
588};
589
590struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
591 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
592
593 std::pair<Value, Value> combine(Location loc, Value scaledExp,
594 Value reciprocalExp, Value sin, Value cos,
595 ConversionPatternRewriter &rewriter,
596 arith::FastMathFlagsAttr fmf) const override {
597 // Complex sine is defined as;
598 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
599 // Plugging in:
600 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
601 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
602 // and defining t := exp(y)
603 // We get:
604 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
605 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
606 Value sum =
607 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
608 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
609 Value diff =
610 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
611 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
612 return {resultReal, resultImag};
613 }
614};
615
616// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
617struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
618 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
619
620 LogicalResult
621 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
622 ConversionPatternRewriter &rewriter) const override {
623 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
624
625 auto type = cast<ComplexType>(op.getType());
626 auto elementType = cast<FloatType>(type.getElementType());
627 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
628
629 auto cst = [&](APFloat v) {
630 return arith::ConstantOp::create(b, elementType,
631 b.getFloatAttr(elementType, v));
632 };
633 const auto &floatSemantics = elementType.getFloatSemantics();
634 Value zero = cst(APFloat::getZero(floatSemantics));
635 Value half = arith::ConstantOp::create(b, elementType,
636 b.getFloatAttr(elementType, 0.5));
637
638 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
639 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
640 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
641 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
642 Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
643 Value cos = math::CosOp::create(b, sqrtArg, fmf);
644 Value sin = math::SinOp::create(b, sqrtArg, fmf);
645 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
646 // 0 * inf.
647 Value sinIsZero =
648 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
649
650 Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
651 Value resultImag = arith::SelectOp::create(
652 b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
653 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
654 arith::FastMathFlags::ninf)) {
655 Value inf = cst(APFloat::getInf(floatSemantics));
656 Value negInf = cst(APFloat::getInf(floatSemantics, true));
657 Value nan = cst(APFloat::getNaN(floatSemantics));
658 Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
659
660 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
661 absImag, inf, fmf);
662 Value absImagIsNotInf = arith::CmpFOp::create(
663 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
664 Value realIsInf =
665 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
666 Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
667 real, negInf, fmf);
668
669 resultReal = arith::SelectOp::create(
670 b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
671 resultReal);
672 resultReal = arith::SelectOp::create(
673 b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
674
675 Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
676 resultImag = arith::SelectOp::create(
677 b,
678 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
679 nan, resultImag);
680 resultImag = arith::SelectOp::create(
681 b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
682 resultImag);
683 }
684
685 Value resultIsZero =
686 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
687 resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
688 resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
689
690 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
691 resultImag);
692 return success();
693 }
694};
695
696struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
697 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
698
699 LogicalResult
700 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
701 ConversionPatternRewriter &rewriter) const override {
702 auto type = cast<ComplexType>(adaptor.getComplex().getType());
703 auto elementType = cast<FloatType>(type.getElementType());
704 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
705 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
706
707 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
708 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
709 Value zero =
710 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
711 Value realIsZero =
712 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
713 Value imagIsZero =
714 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
715 Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
716 auto abs =
717 complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
718 Value realSign = arith::DivFOp::create(b, real, abs, fmf);
719 Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
720 Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
721 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
722 adaptor.getComplex(), sign);
723 return success();
724 }
725};
726
727template <typename Op>
728struct TanTanhOpConversion : public OpConversionPattern<Op> {
729 using OpConversionPattern<Op>::OpConversionPattern;
730
731 LogicalResult
732 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
733 ConversionPatternRewriter &rewriter) const override {
734 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
735 auto loc = op.getLoc();
736 auto type = cast<ComplexType>(adaptor.getComplex().getType());
737 auto elementType = cast<FloatType>(type.getElementType());
738 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
739 const auto &floatSemantics = elementType.getFloatSemantics();
740
741 Value real =
742 complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
743 Value imag =
744 complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
745 Value negOne = arith::ConstantOp::create(b, elementType,
746 b.getFloatAttr(elementType, -1.0));
747
748 if constexpr (std::is_same_v<Op, complex::TanOp>) {
749 // tan(x+yi) = -i*tanh(-y + xi)
750 std::swap(real, imag);
751 real = arith::MulFOp::create(b, real, negOne, fmf);
752 }
753
754 auto cst = [&](APFloat v) {
755 return arith::ConstantOp::create(b, elementType,
756 b.getFloatAttr(elementType, v));
757 };
758 Value inf = cst(APFloat::getInf(floatSemantics));
759 Value four = arith::ConstantOp::create(b, elementType,
760 b.getFloatAttr(elementType, 4.0));
761 Value twoReal = arith::AddFOp::create(b, real, real, fmf);
762 Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
763
764 Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
765 Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
766 Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
767 expNegTwoRealMinusOne, fmf);
768
769 Value cosImag = math::CosOp::create(b, imag, fmf);
770 Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
771 Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
772 Value sinImag = math::SinOp::create(b, imag, fmf);
773
774 Value imagNum = arith::MulFOp::create(
775 b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
776
777 Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
778 expNegTwoRealMinusOne, fmf);
779 Value denom =
780 arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
781
782 Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
783 expSumMinusTwo, inf, fmf);
784 Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
785
786 Value resultReal = arith::SelectOp::create(
787 b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
788 Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
789
790 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
791 arith::FastMathFlags::ninf)) {
792 Value absReal = math::AbsFOp::create(b, real, fmf);
793 Value zero = arith::ConstantOp::create(b, elementType,
794 b.getFloatAttr(elementType, 0.0));
795 Value nan = cst(APFloat::getNaN(floatSemantics));
796
797 Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
798 absReal, inf, fmf);
799 Value imagIsZero =
800 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
801 Value absRealIsNotInf = arith::XOrIOp::create(
802 b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
803
804 Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
805 imagNum, imagNum, fmf);
806 Value resultRealIsNaN =
807 arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
808 Value resultImagIsZero = arith::OrIOp::create(
809 b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
810
811 resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
812 resultImag =
813 arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
814 }
815
816 if constexpr (std::is_same_v<Op, complex::TanOp>) {
817 // tan(x+yi) = -i*tanh(-y + xi)
818 std::swap(resultReal, resultImag);
819 resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
820 }
821
822 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
823 resultImag);
824 return success();
825 }
826};
827
828struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
829 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
830
831 LogicalResult
832 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
833 ConversionPatternRewriter &rewriter) const override {
834 auto loc = op.getLoc();
835 auto type = cast<ComplexType>(adaptor.getComplex().getType());
836 auto elementType = cast<FloatType>(type.getElementType());
837 Value real =
838 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
839 Value imag =
840 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
841 Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
842
843 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
844
845 return success();
846 }
847};
848
849/// Converts lhs^y = (a+bi)^(c+di) to
850/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
851/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
852static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
853 ComplexType type, Value lhs, Value c, Value d,
854 arith::FastMathFlags fmf) {
855 auto elementType = cast<FloatType>(type.getElementType());
856
857 Value a = complex::ReOp::create(builder, lhs);
858 Value b = complex::ImOp::create(builder, lhs);
859
860 Value abs = complex::AbsOp::create(builder, lhs, fmf);
861 Value absToC = math::PowFOp::create(builder, abs, c, fmf);
862
863 Value negD = arith::NegFOp::create(builder, d, fmf);
864 Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
865 Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
866 Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
867
868 Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
869 Value lnAbs = math::LogOp::create(builder, abs, fmf);
870 Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
871 Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
872 Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
873 Value cosQ = math::CosOp::create(builder, q, fmf);
874 Value sinQ = math::SinOp::create(builder, q, fmf);
875
876 Value inf = arith::ConstantOp::create(
877 builder, elementType,
878 builder.getFloatAttr(elementType,
879 APFloat::getInf(elementType.getFloatSemantics())));
880 Value zero = arith::ConstantOp::create(
881 builder, elementType, builder.getFloatAttr(elementType, 0.0));
882 Value one = arith::ConstantOp::create(builder, elementType,
883 builder.getFloatAttr(elementType, 1.0));
884 Value complexOne = complex::CreateOp::create(builder, type, one, zero);
885 Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
886 Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
887
888 // Case 0:
889 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
890 // Branch Cuts for Complex Elementary Functions or Much Ado About
891 // Nothing's Sign Bit, W. Kahan, Section 10.
892 Value absEqZero =
893 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
894 Value dEqZero =
895 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
896 Value cEqZero =
897 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
898 Value bEqZero =
899 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
900
901 Value zeroLeC =
902 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
903 Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
904 Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
905 Value complexOneOrZero =
906 arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
907 Value coeffCosSin =
908 complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
909 Value cutoff0 = arith::SelectOp::create(
910 builder,
911 arith::AndIOp::create(
912 builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
913 complexOneOrZero, coeffCosSin);
914
915 // Case 1:
916 // x^0 is defined to be 1 for any x, see
917 // Branch Cuts for Complex Elementary Functions or Much Ado About
918 // Nothing's Sign Bit, W. Kahan, Section 10.
919 Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
920 Value cutoff1 =
921 arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
922
923 // Case 2:
924 // 1^(c + d*i) = 1 + 0*i
925 Value lhsEqOne = arith::AndIOp::create(
926 builder,
927 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
928 bEqZero);
929 Value cutoff2 =
930 arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
931
932 // Case 3:
933 // inf^(c + 0*i) = inf + 0*i, c > 0
934 Value lhsEqInf = arith::AndIOp::create(
935 builder,
936 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
937 bEqZero);
938 Value rhsGt0 = arith::AndIOp::create(
939 builder, dEqZero,
940 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
941 Value cutoff3 = arith::SelectOp::create(
942 builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
943 cutoff2);
944
945 // Case 4:
946 // inf^(c + 0*i) = 0 + 0*i, c < 0
947 Value rhsLt0 = arith::AndIOp::create(
948 builder, dEqZero,
949 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
950 Value cutoff4 = arith::SelectOp::create(
951 builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
952 cutoff3);
953
954 return cutoff4;
955}
956
957struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
958 using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
959
960 LogicalResult
961 matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
962 ConversionPatternRewriter &rewriter) const override {
963 ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
964 auto type = cast<ComplexType>(op.getType());
965 auto elementType = cast<FloatType>(type.getElementType());
966
967 Value floatExponent =
968 arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
969 Value zero = arith::ConstantOp::create(
970 builder, elementType, builder.getFloatAttr(elementType, 0.0));
971 Value complexExponent =
972 complex::CreateOp::create(builder, type, floatExponent, zero);
973
974 auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
975 complexExponent, op.getFastmathAttr());
976 rewriter.replaceOp(op, pow.getResult());
977 return success();
978 }
979};
980
981struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
982 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
983
984 LogicalResult
985 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
986 ConversionPatternRewriter &rewriter) const override {
987 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
988 auto type = cast<ComplexType>(adaptor.getLhs().getType());
989 auto elementType = cast<FloatType>(type.getElementType());
990
991 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
992 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
993
994 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
995 c, d, op.getFastmath())});
996 return success();
997 }
998};
999
1000struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1001 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
1002
1003 LogicalResult
1004 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1005 ConversionPatternRewriter &rewriter) const override {
1006 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1007 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1008 auto elementType = cast<FloatType>(type.getElementType());
1009
1010 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1011
1012 auto cst = [&](APFloat v) {
1013 return arith::ConstantOp::create(b, elementType,
1014 b.getFloatAttr(elementType, v));
1015 };
1016 const auto &floatSemantics = elementType.getFloatSemantics();
1017 Value zero = cst(APFloat::getZero(floatSemantics));
1018 Value inf = cst(APFloat::getInf(floatSemantics));
1019 Value negHalf = arith::ConstantOp::create(
1020 b, elementType, b.getFloatAttr(elementType, -0.5));
1021 Value nan = cst(APFloat::getNaN(floatSemantics));
1022
1023 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
1024 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
1025 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1026 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
1027 Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
1028 Value cos = math::CosOp::create(b, rsqrtArg, fmf);
1029 Value sin = math::SinOp::create(b, rsqrtArg, fmf);
1030
1031 Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
1032 Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
1033
1034 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1035 arith::FastMathFlags::ninf)) {
1036 Value negOne = arith::ConstantOp::create(b, elementType,
1037 b.getFloatAttr(elementType, -1));
1038
1039 Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
1040 Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
1041 Value negImagSignedZero =
1042 arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
1043
1044 Value absReal = math::AbsFOp::create(b, real, fmf);
1045 Value absImag = math::AbsFOp::create(b, imag, fmf);
1046
1047 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1048 absImag, inf, fmf);
1049 Value realIsNan =
1050 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
1051 Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1052 absReal, inf, fmf);
1053 Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
1054
1055 Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
1056
1057 resultReal =
1058 arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
1059 resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
1060 resultImag);
1061 }
1062
1063 Value isRealZero =
1064 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
1065 Value isImagZero =
1066 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
1067 Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
1068
1069 resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
1070 resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
1071
1072 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1073 resultImag);
1074 return success();
1075 }
1076};
1077
1078struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1079 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1080
1081 LogicalResult
1082 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1083 ConversionPatternRewriter &rewriter) const override {
1084 auto loc = op.getLoc();
1085 auto type = op.getType();
1086 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1087
1088 Value real =
1089 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1090 Value imag =
1091 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
1092
1093 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1094
1095 return success();
1096 }
1097};
1098
1099} // namespace
1100
1102 RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1103 // clang-format off
1104 patterns.add<
1105 AbsOpConversion,
1106 AngleOpConversion,
1107 Atan2OpConversion,
1108 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1109 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1110 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1111 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1112 ConjOpConversion,
1113 CosOpConversion,
1114 ExpOpConversion,
1115 Expm1OpConversion,
1116 Log1pOpConversion,
1117 LogOpConversion,
1118 MulOpConversion,
1119 NegOpConversion,
1120 SignOpConversion,
1121 SinOpConversion,
1122 SqrtOpConversion,
1123 TanTanhOpConversion<complex::TanOp>,
1124 TanTanhOpConversion<complex::TanhOp>,
1125 PowiOpConversion,
1126 PowOpConversion,
1127 RsqrtOpConversion
1128 >(patterns.getContext());
1129
1130 patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
1131
1132 // clang-format on
1133}
1134
1135namespace {
1136struct ConvertComplexToStandardPass
1138 ConvertComplexToStandardPass> {
1139 using Base::Base;
1140
1141 void runOnOperation() override;
1142};
1143
1144void ConvertComplexToStandardPass::runOnOperation() {
1145 // Convert to the Standard dialect using the converter defined above.
1146 RewritePatternSet patterns(&getContext());
1148
1149 ConversionTarget target(getContext());
1150 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1151 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1152 if (failed(
1153 applyPartialConversion(getOperation(), target, std::move(patterns))))
1154 signalPassFailure();
1155}
1156} // namespace
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc()
The source location the operation was defined or derived from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
const FrozenRewritePatternSet & patterns