MLIR 23.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include <type_traits>
18
19namespace mlir {
20#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27
/// Selects which quantity computeAbs emits for a complex value z.
enum class AbsFn {
  abs,   ///< |z|
  sqrt,  ///< sqrt(|z|)
  rsqrt  ///< 1 / sqrt(|z|)
};
29
30// Returns the absolute value, its square root or its reciprocal square root.
31Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
33 Value one = arith::ConstantOp::create(b, real.getType(),
34 b.getFloatAttr(real.getType(), 1.0));
35
36 Value absReal = math::AbsFOp::create(b, real, fmf);
37 Value absImag = math::AbsFOp::create(b, imag, fmf);
38
39 Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
41
42 // The lowering below requires NaNs and infinities to work correctly.
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
49
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
54 }
55
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
58 b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
59 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
60 Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
61 Value p025 =
62 math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
64 } else {
65 Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
67 }
68
69 Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70 result, fmfWithNaNInf);
71 return arith::SelectOp::create(b, isNaN, min, result);
72}
73
74struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
75 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
76
77 LogicalResult
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
80 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
81
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
83
84 Value real = complex::ReOp::create(b, adaptor.getComplex());
85 Value imag = complex::ImOp::create(b, adaptor.getComplex());
86 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
87
88 return success();
89 }
90};
91
92// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
93struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
94 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
95
96 LogicalResult
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
100
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
104
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
107
108 Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
114
115 Value zero =
116 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(b, elementType,
118 b.getFloatAttr(elementType, 1));
119 Value i = complex::CreateOp::create(b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
122
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(b, divResult, fmf);
126
127 Value negativeOne = arith::ConstantOp::create(
128 b, elementType, b.getFloatAttr(elementType, -1));
129 Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
130
131 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
132 return success();
133 }
134};
135
136template <typename ComparisonOp, arith::CmpFPredicate p>
137struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
138 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
142
143 LogicalResult
144 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
148
149 Value realLhs =
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
151 Value imagLhs =
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
153 Value realRhs =
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
155 Value imagRhs =
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
161
162 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
163 imagComparison);
164 return success();
165 }
166};
167
168// Default conversion which applies the BinaryStandardOp separately on the real
169// and imaginary parts. Can for example be used for complex::AddOp and
170// complex::SubOp.
171template <typename BinaryComplexOp, typename BinaryStandardOp>
172struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
173 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
174
175 LogicalResult
176 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter) const override {
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
180 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
182
183 Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
191 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
192 resultImag);
193 return success();
194 }
195};
196
197template <typename TrigonometricOp>
198struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
199 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
200
201 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
202
203 LogicalResult
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
210
211 Value real =
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
213 Value imag =
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
215
216 // Trigonometric ops use a set of common building blocks to convert to real
217 // ops. Here we create these building blocks and call into an op-specific
218 // implementation in the subclass to combine them.
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc, real, fmf);
226
227 auto resultPair =
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
229
230 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
231 resultPair.second);
232 return success();
233 }
234
235 virtual std::pair<Value, Value>
236 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
237 Value cos, ConversionPatternRewriter &rewriter,
238 arith::FastMathFlagsAttr fmf) const = 0;
239};
240
241struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
243
244 std::pair<Value, Value> combine(Location loc, Value scaledExp,
245 Value reciprocalExp, Value sin, Value cos,
246 ConversionPatternRewriter &rewriter,
247 arith::FastMathFlagsAttr fmf) const override {
248 // Complex cosine is defined as;
249 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
250 // Plugging in:
251 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
252 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
253 // and defining t := exp(y)
254 // We get:
255 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
256 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
257 Value sum =
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
260 Value diff =
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
264 }
265};
266
267struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
268 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
269 : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
270
271 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
272
273 LogicalResult
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
280
281 Value lhsReal =
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
283 Value lhsImag =
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
285 Value rhsReal =
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
287 Value rhsImag =
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
289
290 Value resultReal, resultImag;
291
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
296 &resultImag);
297 } else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
300 &resultImag);
301 }
302
303 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
304 resultImag);
305
306 return success();
307 }
308
309private:
310 complex::ComplexRangeFlags complexRange;
311};
312
313struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
314 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
315
316 // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317 // Handle special cases as StableHLO implementation does:
318 // 1. When b == 0, set imag(exp(z)) = 0
319 // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
320 LogicalResult
321 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323 auto loc = op.getLoc();
324 auto type = cast<ComplexType>(adaptor.getComplex().getType());
325 auto ET = cast<FloatType>(type.getElementType());
326 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327 const auto &floatSemantics = ET.getFloatSemantics();
328 ImplicitLocOpBuilder b(loc, rewriter);
329
330 Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331 Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
332 Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
333 Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
334 Value inf = arith::ConstantOp::create(
335 b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
336
337 Value exp = math::ExpOp::create(b, x, fmf);
338 Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339 Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340 Value cos = math::CosOp::create(b, y, fmf);
341 Value sin = math::SinOp::create(b, y, fmf);
342
343 Value expIsInf =
344 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345 Value yIsZero =
346 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
347
348 // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349 Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350 Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351 Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
352 Value resultReal =
353 arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
354
355 // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356 // exp(x/2)*sin(y)*exp(x/2)
357 Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358 Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359 Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
360 Value imagNonZero =
361 arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362 Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
363
364 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
365 resultImag);
366 return success();
367 }
368};
369
370Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
371 ArrayRef<double> coefficients,
372 arith::FastMathFlagsAttr fmf) {
373 auto argType = mlir::cast<FloatType>(arg.getType());
374 Value poly =
375 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
376 for (unsigned i = 1; i < coefficients.size(); ++i) {
377 poly = math::FmaOp::create(
378 b, poly, arg,
379 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
380 fmf);
381 }
382 return poly;
383}
384
385struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
386 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
387
388 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
389 // [handle inaccuracies when a and/or b are small]
390 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
391 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
392 LogicalResult
393 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
394 ConversionPatternRewriter &rewriter) const override {
395 auto type = op.getType();
396 auto elemType = mlir::cast<FloatType>(type.getElementType());
397
398 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
399 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
400 Value real = complex::ReOp::create(b, adaptor.getComplex());
401 Value imag = complex::ImOp::create(b, adaptor.getComplex());
402
403 Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
404 Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
405
406 Value expm1Real = math::ExpM1Op::create(b, real, fmf);
407 Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
408
409 Value sinImag = math::SinOp::create(b, imag, fmf);
410 Value cosm1Imag = emitCosm1(imag, fmf, b);
411 Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
412
413 Value realResult = arith::AddFOp::create(
414 b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
415
416 Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
417 zero, fmf.getValue());
418 Value imagResult = arith::SelectOp::create(
419 b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
420
421 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
422 imagResult);
423 return success();
424 }
425
426private:
427 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
428 ImplicitLocOpBuilder &b) const {
429 auto argType = mlir::cast<FloatType>(arg.getType());
430 auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
431 auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
432
433 // Algorithm copied from cephes cosm1.
434 SmallVector<double, 7> kCoeffs{
435 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438 4.1666666666666666609054E-2,
439 };
440 Value cos = math::CosOp::create(b, arg, fmf);
441 Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
442
443 Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
444 Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
445 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
446
447 auto forSmallArg =
448 arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
449 arith::MulFOp::create(b, negHalf, argPow2, fmf));
450
451 // (pi/4)^2 is approximately 0.61685
452 Value piOver4Pow2 =
453 arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
454 Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
455 piOver4Pow2, fmf.getValue());
456 return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
457 }
458};
459
460struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
461 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
462
463 LogicalResult
464 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const override {
466 auto type = cast<ComplexType>(adaptor.getComplex().getType());
467 auto elementType = cast<FloatType>(type.getElementType());
468 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
469 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
470
471 Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
472 fmf.getValue());
473 Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
474 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
475 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
476 Value resultImag =
477 math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
478 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
479 resultImag);
480 return success();
481 }
482};
483
484struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
485 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
486
487 LogicalResult
488 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 auto type = cast<ComplexType>(adaptor.getComplex().getType());
491 auto elementType = cast<FloatType>(type.getElementType());
492 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
493 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
494
495 Value real = complex::ReOp::create(b, adaptor.getComplex());
496 Value imag = complex::ImOp::create(b, adaptor.getComplex());
497
498 Value half = arith::ConstantOp::create(b, elementType,
499 b.getFloatAttr(elementType, 0.5));
500 Value one = arith::ConstantOp::create(b, elementType,
501 b.getFloatAttr(elementType, 1));
502 Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
503 Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
504 Value absImag = math::AbsFOp::create(b, imag, fmf);
505
506 Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
507 Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
508
509 Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
510 realPlusOne, absImag, fmf);
511 Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
512 Value maxAbsOfRealPlusOneAndImagMinusOne =
513 arith::SelectOp::create(b, useReal, real, maxMinusOne);
514 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516 Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
517 Value logOfMaxAbsOfRealPlusOneAndImag =
518 math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519 Value logOfSqrtPart = math::Log1pOp::create(
520 b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
521 fmfWithNaNInf);
522 Value r = arith::AddFOp::create(
523 b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
524 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525 Value resultReal = arith::SelectOp::create(
526 b,
527 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
528 fmfWithNaNInf),
529 minAbs, r);
530 Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
531 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
532 resultImag);
533 return success();
534 }
535};
536
537struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
538 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
539
540 LogicalResult
541 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter) const override {
543 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
544 auto type = cast<ComplexType>(adaptor.getLhs().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547 auto fmfValue = fmf.getValue();
548 Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
549 Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
550 Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
551 Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
552 Value lhsRealTimesRhsReal =
553 arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
554 Value lhsImagTimesRhsImag =
555 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
556 Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
557 lhsImagTimesRhsImag, fmfValue);
558 Value lhsImagTimesRhsReal =
559 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
560 Value lhsRealTimesRhsImag =
561 arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
562 Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
563 lhsRealTimesRhsImag, fmfValue);
564 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
565 return success();
566 }
567};
568
569struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
570 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
571
572 LogicalResult
573 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
574 ConversionPatternRewriter &rewriter) const override {
575 auto loc = op.getLoc();
576 auto type = cast<ComplexType>(adaptor.getComplex().getType());
577 auto elementType = cast<FloatType>(type.getElementType());
578
579 Value real =
580 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
581 Value imag =
582 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
583 Value negReal = arith::NegFOp::create(rewriter, loc, real);
584 Value negImag = arith::NegFOp::create(rewriter, loc, imag);
585 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
586 return success();
587 }
588};
589
590struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
591 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
592
593 std::pair<Value, Value> combine(Location loc, Value scaledExp,
594 Value reciprocalExp, Value sin, Value cos,
595 ConversionPatternRewriter &rewriter,
596 arith::FastMathFlagsAttr fmf) const override {
597 // Complex sine is defined as;
598 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
599 // Plugging in:
600 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
601 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
602 // and defining t := exp(y)
603 // We get:
604 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
605 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
606 Value sum =
607 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
608 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
609 Value diff =
610 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
611 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
612 return {resultReal, resultImag};
613 }
614};
615
616// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
617struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
618 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
619
620 LogicalResult
621 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
622 ConversionPatternRewriter &rewriter) const override {
623 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
624
625 auto type = cast<ComplexType>(op.getType());
626 auto elementType = cast<FloatType>(type.getElementType());
627 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
628
629 auto cst = [&](APFloat v) {
630 return arith::ConstantOp::create(b, elementType,
631 b.getFloatAttr(elementType, v));
632 };
633 const auto &floatSemantics = elementType.getFloatSemantics();
634 Value zero = cst(APFloat::getZero(floatSemantics));
635 Value half = arith::ConstantOp::create(b, elementType,
636 b.getFloatAttr(elementType, 0.5));
637
638 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
639 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
640 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
641 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
642 Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
643 Value cos = math::CosOp::create(b, sqrtArg, fmf);
644 Value sin = math::SinOp::create(b, sqrtArg, fmf);
645 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
646 // 0 * inf.
647 Value sinIsZero =
648 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
649
650 Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
651 Value resultImag = arith::SelectOp::create(
652 b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
653 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
654 arith::FastMathFlags::ninf)) {
655 Value inf = cst(APFloat::getInf(floatSemantics));
656 Value negInf = cst(APFloat::getInf(floatSemantics, true));
657 Value nan = cst(APFloat::getNaN(floatSemantics));
658 Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
659
660 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
661 absImag, inf, fmf);
662 Value absImagIsNotInf = arith::CmpFOp::create(
663 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
664 Value realIsInf =
665 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
666 Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
667 real, negInf, fmf);
668
669 resultReal = arith::SelectOp::create(
670 b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
671 resultReal);
672 resultReal = arith::SelectOp::create(
673 b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
674
675 Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
676 resultImag = arith::SelectOp::create(
677 b,
678 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
679 nan, resultImag);
680 resultImag = arith::SelectOp::create(
681 b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
682 resultImag);
683 }
684
685 Value resultIsZero =
686 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
687 resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
688 resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
689
690 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
691 resultImag);
692 return success();
693 }
694};
695
696struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
697 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
698
699 LogicalResult
700 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
701 ConversionPatternRewriter &rewriter) const override {
702 auto type = cast<ComplexType>(adaptor.getComplex().getType());
703 auto elementType = cast<FloatType>(type.getElementType());
704 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
705 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
706
707 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
708 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
709 Value zero =
710 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
711 Value realIsZero =
712 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
713 Value imagIsZero =
714 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
715 Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
716 auto abs =
717 complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
718 Value realSign = arith::DivFOp::create(b, real, abs, fmf);
719 Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
720 Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
721 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
722 adaptor.getComplex(), sign);
723 return success();
724 }
725};
726
/// Lowers `complex.tanh`, and `complex.tan` via tan(z) = -i * tanh(i*z), to
/// arith/math ops.
///
/// For a tanh argument x + iy the lowering evaluates
///   tanh(x + iy) = (2*sinh(2x) + i*2*sin(2y)) / (2*cosh(2x) + 2*cos(2y))
/// building every term from expm1:
///   real numerator = expm1(2x) - expm1(-2x)             = e^{2x} - e^{-2x}
///   imag numerator = 4*cos(y)*sin(y)                     = 2*sin(2y)
///   denominator    = expm1(2x) + expm1(-2x) + 4*cos^2(y)
///                  = e^{2x} + e^{-2x} + 2*cos(2y)
template <typename Op>
struct TanTanhOpConversion : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Only passed explicitly to the ReOp/ImOp creations below; `b` already
    // carries the same location.
    auto loc = op.getLoc();
    auto type = cast<ComplexType>(adaptor.getComplex().getType());
    auto elementType = cast<FloatType>(type.getElementType());
    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
    const auto &floatSemantics = elementType.getFloatSemantics();

    Value real =
        complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
    Value imag =
        complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
    Value negOne = arith::ConstantOp::create(b, elementType,
                                             b.getFloatAttr(elementType, -1.0));

    if constexpr (std::is_same_v<Op, complex::TanOp>) {
      // tan(x+yi) = -i*tanh(-y + xi)
      // Rotate the input: (x, y) -> (-y, x), so that the code below computes
      // tanh(i*z).
      std::swap(real, imag);
      real = arith::MulFOp::create(b, real, negOne, fmf);
    }
    // From here on, `real`/`imag` are the components of the tanh argument.

    auto cst = [&](APFloat v) {
      return arith::ConstantOp::create(b, elementType,
                                       b.getFloatAttr(elementType, v));
    };
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value four = arith::ConstantOp::create(b, elementType,
                                           b.getFloatAttr(elementType, 4.0));
    Value twoReal = arith::AddFOp::create(b, real, real, fmf);
    Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);

    // e^{2x} - 1 and e^{-2x} - 1.
    Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
    Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
    // e^{2x} - e^{-2x}  (= 2*sinh(2x)).
    Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
                                          expNegTwoRealMinusOne, fmf);

    Value cosImag = math::CosOp::create(b, imag, fmf);
    Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
    // 4*cos^2(y)  (= 2*cos(2y) + 2).
    Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
    Value sinImag = math::SinOp::create(b, imag, fmf);

    // 4*cos(y)*sin(y)  (= 2*sin(2y)).
    Value imagNum = arith::MulFOp::create(
        b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);

    // e^{2x} + e^{-2x} - 2.
    Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
                                                 expNegTwoRealMinusOne, fmf);
    Value denom =
        arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);

    // When the exponential sum overflows to +inf (large |x|), realNum/denom
    // would be inf/inf. Substitute the limit copysign(1, x) for the real part
    // instead.
    Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                        expSumMinusTwo, inf, fmf);
    Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);

    Value resultReal = arith::SelectOp::create(
        b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
    Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);

    // Special-value fixups, emitted unless both nnan and ninf are set:
    //  * the real part becomes NaN when imagNum is NaN (e.g. imag is NaN or
    //    +-inf) and |real| is not inf;
    //  * the imag part becomes 0 when imag == 0, or when |real| is inf and
    //    imagNum is NaN.
    // NOTE(review): presumably chosen to match the C Annex G special values
    // for ctanh; confirm against the complex dialect tests.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value absReal = math::AbsFOp::create(b, real, fmf);
      Value zero = arith::ConstantOp::create(b, elementType,
                                             b.getFloatAttr(elementType, 0.0));
      Value nan = cst(APFloat::getNaN(floatSemantics));

      Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                                 absReal, inf, fmf);
      Value imagIsZero =
          arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
      // Logical NOT of an i1, expressed as XOR with true.
      Value absRealIsNotInf = arith::XOrIOp::create(
          b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));

      Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
                                                 imagNum, imagNum, fmf);
      Value resultRealIsNaN =
          arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
      Value resultImagIsZero = arith::OrIOp::create(
          b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));

      resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
      resultImag =
          arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
    }

    if constexpr (std::is_same_v<Op, complex::TanOp>) {
      // tan(x+yi) = -i*tanh(-y + xi)
      // Undo the rotation: multiplying by -i maps (re, im) -> (im, -re).
      std::swap(resultReal, resultImag);
      resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
    }

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};
827
828struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
829 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
830
831 LogicalResult
832 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
833 ConversionPatternRewriter &rewriter) const override {
834 auto loc = op.getLoc();
835 auto type = cast<ComplexType>(adaptor.getComplex().getType());
836 auto elementType = cast<FloatType>(type.getElementType());
837 Value real =
838 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
839 Value imag =
840 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
841 Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
842
843 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
844
845 return success();
846 }
847};
848
/// Lowers lhs^rhs, where lhs = a + bi and rhs = c + di, to
///   (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
/// where q = c*atan2(b,a) + 0.5d*ln(a*a+b*b).
/// The generic result is then overridden by a chain of selects for the
/// special cases documented inline (later cases take precedence).
///
/// \param builder  Builder carrying the location of the op being lowered.
/// \param type     Complex type of `lhs` and of the result.
/// \param lhs      The complex base.
/// \param c        Real part of the exponent.
/// \param d        Imaginary part of the exponent.
/// \param fmf      Fast-math flags attached to every emitted float op.
/// \returns The complex result value.
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
                                 ComplexType type, Value lhs, Value c, Value d,
                                 arith::FastMathFlags fmf) {
  auto elementType = cast<FloatType>(type.getElementType());

  Value a = complex::ReOp::create(builder, lhs);
  Value b = complex::ImOp::create(builder, lhs);

  // |lhs|^c == (a*a+b*b)^(0.5c).
  Value abs = complex::AbsOp::create(builder, lhs, fmf);
  Value absToC = math::PowFOp::create(builder, abs, c, fmf);

  Value negD = arith::NegFOp::create(builder, d, fmf);
  Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
  Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
  Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);

  // Magnitude of the result and its phase q = c*arg(lhs) + d*ln|lhs|.
  Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
  Value lnAbs = math::LogOp::create(builder, abs, fmf);
  Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
  Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
  Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
  Value cosQ = math::CosOp::create(builder, q, fmf);
  Value sinQ = math::SinOp::create(builder, q, fmf);

  Value inf = arith::ConstantOp::create(
      builder, elementType,
      builder.getFloatAttr(elementType,
                           APFloat::getInf(elementType.getFloatSemantics())));
  Value zero = arith::ConstantOp::create(
      builder, elementType, builder.getFloatAttr(elementType, 0.0));
  Value one = arith::ConstantOp::create(builder, elementType,
                                        builder.getFloatAttr(elementType, 1.0));
  Value complexOne = complex::CreateOp::create(builder, type, one, zero);
  Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
  Value complexInf = complex::CreateOp::create(builder, type, inf, zero);

  // Case 0:
  // 0^(c + 0*i) is 0 if c > 0, and 0^0 is defined to be 1.0. This applies
  // when |lhs| == 0, d == 0 and c >= 0. See Branch Cuts for Complex
  // Elementary Functions or Much Ado About Nothing's Sign Bit, W. Kahan,
  // Section 10.
  Value absEqZero =
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
  Value dEqZero =
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
  Value cEqZero =
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
  Value bEqZero =
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);

  Value zeroLeC =
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
  Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
  Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
  Value complexOneOrZero =
      arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
  Value coeffCosSin =
      complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
  Value cutoff0 = arith::SelectOp::create(
      builder,
      arith::AndIOp::create(
          builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
      complexOneOrZero, coeffCosSin);

  // Case 1:
  // x^0 is defined to be 1 for any x, see
  // Branch Cuts for Complex Elementary Functions or Much Ado About
  // Nothing's Sign Bit, W. Kahan, Section 10.
  Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
  Value cutoff1 =
      arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);

  // Case 2:
  // 1^(c + d*i) = 1 + 0*i
  Value lhsEqOne = arith::AndIOp::create(
      builder,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
      bEqZero);
  Value cutoff2 =
      arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);

  // Case 3:
  // inf^(c + 0*i) = inf + 0*i, c > 0 (only lhs == +inf + 0*i is matched).
  Value lhsEqInf = arith::AndIOp::create(
      builder,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
      bEqZero);
  Value rhsGt0 = arith::AndIOp::create(
      builder, dEqZero,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
  Value cutoff3 = arith::SelectOp::create(
      builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
      cutoff2);

  // Case 4:
  // inf^(c + 0*i) = 0 + 0*i, c < 0 (same +inf-only match as case 3).
  Value rhsLt0 = arith::AndIOp::create(
      builder, dEqZero,
      arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
  Value cutoff4 = arith::SelectOp::create(
      builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
      cutoff3);

  return cutoff4;
}
956
957struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
958 using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
959
960 LogicalResult
961 matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
962 ConversionPatternRewriter &rewriter) const override {
963 ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
964 auto type = cast<ComplexType>(op.getType());
965 auto elementType = cast<FloatType>(type.getElementType());
966
967 Value floatExponent =
968 arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
969 Value zero = arith::ConstantOp::create(
970 builder, elementType, builder.getFloatAttr(elementType, 0.0));
971 Value complexExponent =
972 complex::CreateOp::create(builder, type, floatExponent, zero);
973
974 auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
975 complexExponent, op.getFastmathAttr());
976 rewriter.replaceOp(op, pow.getResult());
977 return success();
978 }
979};
980
981struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
982 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
983
984 LogicalResult
985 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
986 ConversionPatternRewriter &rewriter) const override {
987 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
988 auto type = cast<ComplexType>(adaptor.getLhs().getType());
989 auto elementType = cast<FloatType>(type.getElementType());
990
991 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
992 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
993
994 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
995 c, d, op.getFastmath())});
996 return success();
997 }
998};
999
/// Lowers `complex.rsqrt` (1 / sqrt(z)) to arith/math ops using the polar
/// form
///   rsqrt(re + i*im) = |z|^(-1/2) * (cos(-arg/2) + i*sin(-arg/2)),
/// where arg = atan2(im, re) and |z|^(-1/2) is produced by
/// computeAbs(..., AbsFn::rsqrt).
///
/// The generic result is then overridden by selects:
///  * unless both nnan and ninf are set: if |re| == inf, or |im| == inf with
///    re NaN, the result is (copysign(0, re), -copysign(0, im));
///  * if re == 0 and im == 0, the result is (inf, NaN). This override is
///    emitted unconditionally and is applied last.
struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
  using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    auto type = cast<ComplexType>(adaptor.getComplex().getType());
    auto elementType = cast<FloatType>(type.getElementType());

    arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();

    auto cst = [&](APFloat v) {
      return arith::ConstantOp::create(b, elementType,
                                       b.getFloatAttr(elementType, v));
    };
    const auto &floatSemantics = elementType.getFloatSemantics();
    Value zero = cst(APFloat::getZero(floatSemantics));
    Value inf = cst(APFloat::getInf(floatSemantics));
    Value negHalf = arith::ConstantOp::create(
        b, elementType, b.getFloatAttr(elementType, -0.5));
    Value nan = cst(APFloat::getNaN(floatSemantics));

    Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
    Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
    // Magnitude of the result: 1 / sqrt(|z|).
    Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
    // Phase of the result: -arg(z) / 2.
    Value argArg = math::Atan2Op::create(b, imag, real, fmf);
    Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
    Value cos = math::CosOp::create(b, rsqrtArg, fmf);
    Value sin = math::SinOp::create(b, rsqrtArg, fmf);

    Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
    Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);

    // Infinite-input fixups, skipped only when both nnan and ninf are set.
    if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
                                            arith::FastMathFlags::ninf)) {
      Value negOne = arith::ConstantOp::create(b, elementType,
                                               b.getFloatAttr(elementType, -1));

      // Signed zeros: the real part keeps re's sign and the imaginary part
      // takes the opposite of im's sign.
      Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
      Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
      Value negImagSignedZero =
          arith::MulFOp::create(b, negOne, imagSignedZero, fmf);

      Value absReal = math::AbsFOp::create(b, real, fmf);
      Value absImag = math::AbsFOp::create(b, imag, fmf);

      Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                                 absImag, inf, fmf);
      Value realIsNan =
          arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
      Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
                                              absReal, inf, fmf);
      // Input is NaN + i*(+-inf).
      Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);

      Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);

      resultReal =
          arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
      resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
                                           resultImag);
    }

    // rsqrt(0 + 0i) is mapped to inf + NaN*i regardless of fast-math flags.
    Value isRealZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
    Value isImagZero =
        arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
    Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);

    resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
    resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);

    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
                                                   resultImag);
    return success();
  }
};
1077
1078struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1079 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1080
1081 LogicalResult
1082 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1083 ConversionPatternRewriter &rewriter) const override {
1084 auto loc = op.getLoc();
1085 auto type = op.getType();
1086 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1087
1088 Value real =
1089 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1090 Value imag =
1091 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
1092
1093 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1094
1095 return success();
1096 }
1097};
1098
1099} // namespace
1100
    RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
  // All patterns in this list are constructed from the context alone.
  // clang-format off
  patterns.add<
    AbsOpConversion,
    AngleOpConversion,
    Atan2OpConversion,
    BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
    BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
    ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
    ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
    ConjOpConversion,
    CosOpConversion,
    ExpOpConversion,
    Expm1OpConversion,
    Log1pOpConversion,
    LogOpConversion,
    MulOpConversion,
    NegOpConversion,
    SignOpConversion,
    SinOpConversion,
    SqrtOpConversion,
    TanTanhOpConversion<complex::TanOp>,
    TanTanhOpConversion<complex::TanhOp>,
    PowiOpConversion,
    PowOpConversion,
    RsqrtOpConversion
  >(patterns.getContext());

  // Division is registered separately because its pattern also receives
  // `complexRange`. NOTE(review): presumably this selects between the
  // algebraic and the range-reduction (Smith's) division lowerings; confirm
  // in DivOpConversion.
  patterns.add<DivOpConversion>(patterns.getContext(), complexRange);

  // clang-format on
}
1134
1135namespace {
1136struct ConvertComplexToStandardPass
1137 : public impl::ConvertComplexToStandardPassBase<
1138 ConvertComplexToStandardPass> {
1139 using Base::Base;
1140
1141 void runOnOperation() override;
1142};
1143
1144void ConvertComplexToStandardPass::runOnOperation() {
1145 // Convert to the Standard dialect using the converter defined above.
1146 RewritePatternSet patterns(&getContext());
1147 populateComplexToStandardConversionPatterns(patterns, complexRange);
1148
1149 ConversionTarget target(getContext());
1150 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1151 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1152 if (failed(
1153 applyPartialConversion(getOperation(), target, std::move(patterns))))
1154 signalPassFailure();
1155}
1156} // namespace
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.