MLIR 23.0.0git
ComplexToStandard.cpp
Go to the documentation of this file.
1//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
17#include <type_traits>
18
19namespace mlir {
20#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
21#include "mlir/Conversion/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27
// Selects which function of the complex magnitude computeAbs produces.
enum class AbsFn {
  abs,   // The absolute value itself.
  sqrt,  // The square root of the absolute value.
  rsqrt, // The reciprocal square root of the absolute value.
};
29
30// Returns the absolute value, its square root or its reciprocal square root.
31Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
32 ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
33 Value one = arith::ConstantOp::create(b, real.getType(),
34 b.getFloatAttr(real.getType(), 1.0));
35
36 Value absReal = math::AbsFOp::create(b, real, fmf);
37 Value absImag = math::AbsFOp::create(b, imag, fmf);
38
39 Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
40 Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
41
42 // The lowering below requires NaNs and infinities to work correctly.
43 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
44 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
45 Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
46 Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
47 Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
49
50 if (fn == AbsFn::rsqrt) {
51 ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
52 min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
53 max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
54 }
55
56 if (fn == AbsFn::sqrt) {
57 Value quarter = arith::ConstantOp::create(
58 b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
59 // sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
60 Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
61 Value p025 =
62 math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
63 result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
64 } else {
65 Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
66 result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
67 }
68
69 Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
70 result, fmfWithNaNInf);
71 return arith::SelectOp::create(b, isNaN, min, result);
72}
73
74struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
75 using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
76
77 LogicalResult
78 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter) const override {
80 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
81
82 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
83
84 Value real = complex::ReOp::create(b, adaptor.getComplex());
85 Value imag = complex::ImOp::create(b, adaptor.getComplex());
86 rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
87
88 return success();
89 }
90};
91
92// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
93struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
94 using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
95
96 LogicalResult
97 matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
98 ConversionPatternRewriter &rewriter) const override {
99 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
100
101 auto type = cast<ComplexType>(op.getType());
102 Type elementType = type.getElementType();
103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
104
105 Value lhs = adaptor.getLhs();
106 Value rhs = adaptor.getRhs();
107
108 Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
109 Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
110 Value rhsSquaredPlusLhsSquared =
111 complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
112 Value sqrtOfRhsSquaredPlusLhsSquared =
113 complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
114
115 Value zero =
116 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
117 Value one = arith::ConstantOp::create(b, elementType,
118 b.getFloatAttr(elementType, 1));
119 Value i = complex::CreateOp::create(b, type, zero, one);
120 Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
121 Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
122
123 Value divResult = complex::DivOp::create(
124 b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
125 Value logResult = complex::LogOp::create(b, divResult, fmf);
126
127 Value negativeOne = arith::ConstantOp::create(
128 b, elementType, b.getFloatAttr(elementType, -1));
129 Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
130
131 rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
132 return success();
133 }
134};
135
136template <typename ComparisonOp, arith::CmpFPredicate p>
137struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
138 using OpConversionPattern<ComparisonOp>::OpConversionPattern;
139 using ResultCombiner =
140 std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
141 arith::AndIOp, arith::OrIOp>;
142
143 LogicalResult
144 matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 auto loc = op.getLoc();
147 auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
148
149 Value realLhs =
150 complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
151 Value imagLhs =
152 complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
153 Value realRhs =
154 complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
155 Value imagRhs =
156 complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
157 Value realComparison =
158 arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
159 Value imagComparison =
160 arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
161
162 rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
163 imagComparison);
164 return success();
165 }
166};
167
168// Default conversion which applies the BinaryStandardOp separately on the real
169// and imaginary parts. Can for example be used for complex::AddOp and
170// complex::SubOp.
171template <typename BinaryComplexOp, typename BinaryStandardOp>
172struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
173 using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
174
175 LogicalResult
176 matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
177 ConversionPatternRewriter &rewriter) const override {
178 auto type = cast<ComplexType>(adaptor.getLhs().getType());
179 auto elementType = cast<FloatType>(type.getElementType());
180 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
181 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
182
183 Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
184 Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
185 Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
186 realRhs, fmf.getValue());
187 Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
188 Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
189 Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
190 imagRhs, fmf.getValue());
191 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
192 resultImag);
193 return success();
194 }
195};
196
197template <typename TrigonometricOp>
198struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
199 using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
200
201 using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
202
203 LogicalResult
204 matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 auto loc = op.getLoc();
207 auto type = cast<ComplexType>(adaptor.getComplex().getType());
208 auto elementType = cast<FloatType>(type.getElementType());
209 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
210
211 Value real =
212 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
213 Value imag =
214 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
215
216 // Trigonometric ops use a set of common building blocks to convert to real
217 // ops. Here we create these building blocks and call into an op-specific
218 // implementation in the subclass to combine them.
219 Value half = arith::ConstantOp::create(
220 rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
221 Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
222 Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
223 Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
224 Value sin = math::SinOp::create(rewriter, loc, real, fmf);
225 Value cos = math::CosOp::create(rewriter, loc, real, fmf);
226
227 auto resultPair =
228 combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
229
230 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
231 resultPair.second);
232 return success();
233 }
234
235 virtual std::pair<Value, Value>
236 combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
237 Value cos, ConversionPatternRewriter &rewriter,
238 arith::FastMathFlagsAttr fmf) const = 0;
239};
240
241struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
242 using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
243
244 std::pair<Value, Value> combine(Location loc, Value scaledExp,
245 Value reciprocalExp, Value sin, Value cos,
246 ConversionPatternRewriter &rewriter,
247 arith::FastMathFlagsAttr fmf) const override {
248 // Complex cosine is defined as;
249 // cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
250 // Plugging in:
251 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
252 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
253 // and defining t := exp(y)
254 // We get:
255 // Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
256 // Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
257 Value sum =
258 arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
259 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
260 Value diff =
261 arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
262 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
263 return {resultReal, resultImag};
264 }
265};
266
267struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
268 DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
269 : OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
270
271 using OpConversionPattern<complex::DivOp>::OpConversionPattern;
272
273 LogicalResult
274 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
275 ConversionPatternRewriter &rewriter) const override {
276 auto loc = op.getLoc();
277 auto type = cast<ComplexType>(adaptor.getLhs().getType());
278 auto elementType = cast<FloatType>(type.getElementType());
279 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
280
281 Value lhsReal =
282 complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
283 Value lhsImag =
284 complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
285 Value rhsReal =
286 complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
287 Value rhsImag =
288 complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
289
290 Value resultReal, resultImag;
291
292 if (complexRange == complex::ComplexRangeFlags::basic ||
293 complexRange == complex::ComplexRangeFlags::none) {
295 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
296 &resultImag);
297 } else if (complexRange == complex::ComplexRangeFlags::improved) {
299 rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
300 &resultImag);
301 }
302
303 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
304 resultImag);
305
306 return success();
307 }
308
309private:
310 complex::ComplexRangeFlags complexRange;
311};
312
313struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
314 using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
315
316 // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317 // Handle special cases as StableHLO implementation does:
318 // 1. When b == 0, set imag(exp(z)) = 0
319 // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
320 LogicalResult
321 matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323 auto loc = op.getLoc();
324 auto type = cast<ComplexType>(adaptor.getComplex().getType());
325 auto ET = cast<FloatType>(type.getElementType());
326 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
327 const auto &floatSemantics = ET.getFloatSemantics();
328 ImplicitLocOpBuilder b(loc, rewriter);
329
330 Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
331 Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
332 Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
333 Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
334 Value inf = arith::ConstantOp::create(
335 b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
336
337 Value exp = math::ExpOp::create(b, x, fmf);
338 Value xHalf = arith::MulFOp::create(b, x, half, fmf);
339 Value expHalf = math::ExpOp::create(b, xHalf, fmf);
340 Value cos = math::CosOp::create(b, y, fmf);
341 Value sin = math::SinOp::create(b, y, fmf);
342
343 Value expIsInf =
344 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345 Value yIsZero =
346 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
347
348 // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349 Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
350 Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
351 Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
352 Value resultReal =
353 arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
354
355 // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356 // exp(x/2)*sin(y)*exp(x/2)
357 Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
358 Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
359 Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
360 Value imagNonZero =
361 arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
362 Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
363
364 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
365 resultImag);
366 return success();
367 }
368};
369
370Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
371 ArrayRef<double> coefficients,
372 arith::FastMathFlagsAttr fmf) {
373 auto argType = mlir::cast<FloatType>(arg.getType());
374 Value poly =
375 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
376 for (unsigned i = 1; i < coefficients.size(); ++i) {
377 poly = math::FmaOp::create(
378 b, poly, arg,
379 arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
380 fmf);
381 }
382 return poly;
383}
384
385struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
386 using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
387
388 // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
389 // [handle inaccuracies when a and/or b are small]
390 // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
391 // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
392 LogicalResult
393 matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
394 ConversionPatternRewriter &rewriter) const override {
395 auto type = op.getType();
396 auto elemType = mlir::cast<FloatType>(type.getElementType());
397
398 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
399 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
400 Value real = complex::ReOp::create(b, adaptor.getComplex());
401 Value imag = complex::ImOp::create(b, adaptor.getComplex());
402
403 Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
404 Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
405
406 Value expm1Real = math::ExpM1Op::create(b, real, fmf);
407 Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
408
409 Value sinImag = math::SinOp::create(b, imag, fmf);
410 Value cosm1Imag = emitCosm1(imag, fmf, b);
411 Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
412
413 Value realResult = arith::AddFOp::create(
414 b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
415
416 Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
417 zero, fmf.getValue());
418 Value imagResult = arith::SelectOp::create(
419 b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
420
421 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
422 imagResult);
423 return success();
424 }
425
426private:
427 Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
428 ImplicitLocOpBuilder &b) const {
429 auto argType = mlir::cast<FloatType>(arg.getType());
430 auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
431 auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
432
433 // Algorithm copied from cephes cosm1.
434 SmallVector<double, 7> kCoeffs{
435 4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
436 2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
437 2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
438 4.1666666666666666609054E-2,
439 };
440 Value cos = math::CosOp::create(b, arg, fmf);
441 Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
442
443 Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
444 Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
445 Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
446
447 auto forSmallArg =
448 arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
449 arith::MulFOp::create(b, negHalf, argPow2, fmf));
450
451 // (pi/4)^2 is approximately 0.61685
452 Value piOver4Pow2 =
453 arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
454 Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
455 piOver4Pow2, fmf.getValue());
456 return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
457 }
458};
459
460struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
461 using OpConversionPattern<complex::LogOp>::OpConversionPattern;
462
463 LogicalResult
464 matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const override {
466 auto type = cast<ComplexType>(adaptor.getComplex().getType());
467 auto elementType = cast<FloatType>(type.getElementType());
468 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
469 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
470
471 Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
472 fmf.getValue());
473 Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
474 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
475 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
476 Value resultImag =
477 math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
478 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
479 resultImag);
480 return success();
481 }
482};
483
484struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
485 using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
486
487 LogicalResult
488 matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const override {
490 auto type = cast<ComplexType>(adaptor.getComplex().getType());
491 auto elementType = cast<FloatType>(type.getElementType());
492 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
493 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
494
495 Value real = complex::ReOp::create(b, adaptor.getComplex());
496 Value imag = complex::ImOp::create(b, adaptor.getComplex());
497
498 Value half = arith::ConstantOp::create(b, elementType,
499 b.getFloatAttr(elementType, 0.5));
500 Value one = arith::ConstantOp::create(b, elementType,
501 b.getFloatAttr(elementType, 1));
502 Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
503 Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
504 Value absImag = math::AbsFOp::create(b, imag, fmf);
505
506 Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
507 Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
508
509 Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
510 realPlusOne, absImag, fmf);
511 Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
512 Value maxAbsOfRealPlusOneAndImagMinusOne =
513 arith::SelectOp::create(b, useReal, real, maxMinusOne);
514 arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
515 fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
516 Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
517 Value logOfMaxAbsOfRealPlusOneAndImag =
518 math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
519 Value logOfSqrtPart = math::Log1pOp::create(
520 b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
521 fmfWithNaNInf);
522 Value r = arith::AddFOp::create(
523 b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
524 logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
525 Value resultReal = arith::SelectOp::create(
526 b,
527 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
528 fmfWithNaNInf),
529 minAbs, r);
530 Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
531 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
532 resultImag);
533 return success();
534 }
535};
536
537struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
538 using OpConversionPattern<complex::MulOp>::OpConversionPattern;
539
540 LogicalResult
541 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
542 ConversionPatternRewriter &rewriter) const override {
543 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
544 auto type = cast<ComplexType>(adaptor.getLhs().getType());
545 auto elementType = cast<FloatType>(type.getElementType());
546 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
547 auto fmfValue = fmf.getValue();
548 Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
549 Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
550 Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
551 Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
552 Value real;
553 Value imag;
554 if (arith::bitEnumContainsAll(fmfValue, arith::FastMathFlags::contract)) {
555 Value lhsImagTimesRhsImag =
556 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
557 Value negLhsImagTimesRhsImag =
558 arith::NegFOp::create(b, lhsImagTimesRhsImag, fmfValue);
559 real = math::FmaOp::create(b, lhsReal, rhsReal, negLhsImagTimesRhsImag,
560 fmfValue);
561
562 Value lhsImagTimesRhsReal =
563 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
564 imag = math::FmaOp::create(b, lhsReal, rhsImag, lhsImagTimesRhsReal,
565 fmfValue);
566 } else {
567 Value lhsRealTimesRhsReal =
568 arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
569 Value lhsImagTimesRhsImag =
570 arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
571 Value lhsImagTimesRhsReal =
572 arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
573 Value lhsRealTimesRhsImag =
574 arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
575
576 real = arith::SubFOp::create(b, lhsRealTimesRhsReal, lhsImagTimesRhsImag,
577 fmfValue);
578 imag = arith::AddFOp::create(b, lhsImagTimesRhsReal, lhsRealTimesRhsImag,
579 fmfValue);
580 }
581 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
582 return success();
583 }
584};
585
586struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
587 using OpConversionPattern<complex::NegOp>::OpConversionPattern;
588
589 LogicalResult
590 matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
591 ConversionPatternRewriter &rewriter) const override {
592 auto loc = op.getLoc();
593 auto type = cast<ComplexType>(adaptor.getComplex().getType());
594 auto elementType = cast<FloatType>(type.getElementType());
595
596 Value real =
597 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
598 Value imag =
599 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
600 Value negReal = arith::NegFOp::create(rewriter, loc, real);
601 Value negImag = arith::NegFOp::create(rewriter, loc, imag);
602 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
603 return success();
604 }
605};
606
607struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
608 using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
609
610 std::pair<Value, Value> combine(Location loc, Value scaledExp,
611 Value reciprocalExp, Value sin, Value cos,
612 ConversionPatternRewriter &rewriter,
613 arith::FastMathFlagsAttr fmf) const override {
614 // Complex sine is defined as;
615 // sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
616 // Plugging in:
617 // exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
618 // exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
619 // and defining t := exp(y)
620 // We get:
621 // Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
622 // Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
623 Value sum =
624 arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
625 Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
626 Value diff =
627 arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
628 Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
629 return {resultReal, resultImag};
630 }
631};
632
633// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
634struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
635 using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
636
637 LogicalResult
638 matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
639 ConversionPatternRewriter &rewriter) const override {
640 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
641
642 auto type = cast<ComplexType>(op.getType());
643 auto elementType = cast<FloatType>(type.getElementType());
644 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
645
646 auto cst = [&](APFloat v) {
647 return arith::ConstantOp::create(b, elementType,
648 b.getFloatAttr(elementType, v));
649 };
650 const auto &floatSemantics = elementType.getFloatSemantics();
651 Value zero = cst(APFloat::getZero(floatSemantics));
652 Value half = arith::ConstantOp::create(b, elementType,
653 b.getFloatAttr(elementType, 0.5));
654
655 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
656 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
657 Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
658 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
659 Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
660 Value cos = math::CosOp::create(b, sqrtArg, fmf);
661 Value sin = math::SinOp::create(b, sqrtArg, fmf);
662 // sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
663 // 0 * inf.
664 Value sinIsZero =
665 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
666
667 Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
668 Value resultImag = arith::SelectOp::create(
669 b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
670 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
671 arith::FastMathFlags::ninf)) {
672 Value inf = cst(APFloat::getInf(floatSemantics));
673 Value negInf = cst(APFloat::getInf(floatSemantics, true));
674 Value nan = cst(APFloat::getNaN(floatSemantics));
675 Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
676
677 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
678 absImag, inf, fmf);
679 Value absImagIsNotInf = arith::CmpFOp::create(
680 b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
681 Value realIsInf =
682 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
683 Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
684 real, negInf, fmf);
685
686 resultReal = arith::SelectOp::create(
687 b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
688 resultReal);
689 resultReal = arith::SelectOp::create(
690 b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
691
692 Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
693 resultImag = arith::SelectOp::create(
694 b,
695 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
696 nan, resultImag);
697 resultImag = arith::SelectOp::create(
698 b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
699 resultImag);
700 }
701
702 Value resultIsZero =
703 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
704 resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
705 resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
706
707 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
708 resultImag);
709 return success();
710 }
711};
712
713struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
714 using OpConversionPattern<complex::SignOp>::OpConversionPattern;
715
716 LogicalResult
717 matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
718 ConversionPatternRewriter &rewriter) const override {
719 auto type = cast<ComplexType>(adaptor.getComplex().getType());
720 auto elementType = cast<FloatType>(type.getElementType());
721 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
722 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
723
724 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
725 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
726 Value zero =
727 arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
728 Value realIsZero =
729 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
730 Value imagIsZero =
731 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
732 Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
733 auto abs =
734 complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
735 Value realSign = arith::DivFOp::create(b, real, abs, fmf);
736 Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
737 Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
738 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
739 adaptor.getComplex(), sign);
740 return success();
741 }
742};
743
744template <typename Op>
745struct TanTanhOpConversion : public OpConversionPattern<Op> {
746 using OpConversionPattern<Op>::OpConversionPattern;
747
748 LogicalResult
749 matchAndRewrite(Op op, typename Op::Adaptor adaptor,
750 ConversionPatternRewriter &rewriter) const override {
751 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
752 auto loc = op.getLoc();
753 auto type = cast<ComplexType>(adaptor.getComplex().getType());
754 auto elementType = cast<FloatType>(type.getElementType());
755 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
756 const auto &floatSemantics = elementType.getFloatSemantics();
757
758 Value real =
759 complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
760 Value imag =
761 complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
762 Value negOne = arith::ConstantOp::create(b, elementType,
763 b.getFloatAttr(elementType, -1.0));
764
765 if constexpr (std::is_same_v<Op, complex::TanOp>) {
766 // tan(x+yi) = -i*tanh(-y + xi)
767 std::swap(real, imag);
768 real = arith::MulFOp::create(b, real, negOne, fmf);
769 }
770
771 auto cst = [&](APFloat v) {
772 return arith::ConstantOp::create(b, elementType,
773 b.getFloatAttr(elementType, v));
774 };
775 Value inf = cst(APFloat::getInf(floatSemantics));
776 Value four = arith::ConstantOp::create(b, elementType,
777 b.getFloatAttr(elementType, 4.0));
778 Value twoReal = arith::AddFOp::create(b, real, real, fmf);
779 Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
780
781 Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
782 Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
783 Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
784 expNegTwoRealMinusOne, fmf);
785
786 Value cosImag = math::CosOp::create(b, imag, fmf);
787 Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
788 Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
789 Value sinImag = math::SinOp::create(b, imag, fmf);
790
791 Value imagNum = arith::MulFOp::create(
792 b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
793
794 Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
795 expNegTwoRealMinusOne, fmf);
796 Value denom =
797 arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
798
799 Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
800 expSumMinusTwo, inf, fmf);
801 Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
802
803 Value resultReal = arith::SelectOp::create(
804 b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
805 Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
806
807 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
808 arith::FastMathFlags::ninf)) {
809 Value absReal = math::AbsFOp::create(b, real, fmf);
810 Value zero = arith::ConstantOp::create(b, elementType,
811 b.getFloatAttr(elementType, 0.0));
812 Value nan = cst(APFloat::getNaN(floatSemantics));
813
814 Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
815 absReal, inf, fmf);
816 Value imagIsZero =
817 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
818 Value absRealIsNotInf = arith::XOrIOp::create(
819 b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
820
821 Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
822 imagNum, imagNum, fmf);
823 Value resultRealIsNaN =
824 arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
825 Value resultImagIsZero = arith::OrIOp::create(
826 b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
827
828 resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
829 resultImag =
830 arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
831 }
832
833 if constexpr (std::is_same_v<Op, complex::TanOp>) {
834 // tan(x+yi) = -i*tanh(-y + xi)
835 std::swap(resultReal, resultImag);
836 resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
837 }
838
839 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
840 resultImag);
841 return success();
842 }
843};
844
845struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
846 using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
847
848 LogicalResult
849 matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
850 ConversionPatternRewriter &rewriter) const override {
851 auto loc = op.getLoc();
852 auto type = cast<ComplexType>(adaptor.getComplex().getType());
853 auto elementType = cast<FloatType>(type.getElementType());
854 Value real =
855 complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
856 Value imag =
857 complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
858 Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
859
860 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
861
862 return success();
863 }
864};
865
866/// Converts lhs^y = (a+bi)^(c+di) to
867/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
868/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
869static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
870 ComplexType type, Value lhs, Value c, Value d,
871 arith::FastMathFlags fmf) {
872 auto elementType = cast<FloatType>(type.getElementType());
873
874 Value a = complex::ReOp::create(builder, lhs);
875 Value b = complex::ImOp::create(builder, lhs);
876
877 Value abs = complex::AbsOp::create(builder, lhs, fmf);
878 Value absToC = math::PowFOp::create(builder, abs, c, fmf);
879
880 Value negD = arith::NegFOp::create(builder, d, fmf);
881 Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
882 Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
883 Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
884
885 Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
886 Value lnAbs = math::LogOp::create(builder, abs, fmf);
887 Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
888 Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
889 Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
890 Value cosQ = math::CosOp::create(builder, q, fmf);
891 Value sinQ = math::SinOp::create(builder, q, fmf);
892
893 Value inf = arith::ConstantOp::create(
894 builder, elementType,
895 builder.getFloatAttr(elementType,
896 APFloat::getInf(elementType.getFloatSemantics())));
897 Value zero = arith::ConstantOp::create(
898 builder, elementType, builder.getFloatAttr(elementType, 0.0));
899 Value one = arith::ConstantOp::create(builder, elementType,
900 builder.getFloatAttr(elementType, 1.0));
901 Value complexOne = complex::CreateOp::create(builder, type, one, zero);
902 Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
903 Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
904
905 // Case 0:
906 // d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
907 // Branch Cuts for Complex Elementary Functions or Much Ado About
908 // Nothing's Sign Bit, W. Kahan, Section 10.
909 Value absEqZero =
910 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
911 Value dEqZero =
912 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
913 Value cEqZero =
914 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
915 Value bEqZero =
916 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
917
918 Value zeroLeC =
919 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
920 Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
921 Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
922 Value complexOneOrZero =
923 arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
924 Value coeffCosSin =
925 complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
926 Value cutoff0 = arith::SelectOp::create(
927 builder,
928 arith::AndIOp::create(
929 builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
930 complexOneOrZero, coeffCosSin);
931
932 // Case 1:
933 // x^0 is defined to be 1 for any x, see
934 // Branch Cuts for Complex Elementary Functions or Much Ado About
935 // Nothing's Sign Bit, W. Kahan, Section 10.
936 Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
937 Value cutoff1 =
938 arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
939
940 // Case 2:
941 // 1^(c + d*i) = 1 + 0*i
942 Value lhsEqOne = arith::AndIOp::create(
943 builder,
944 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
945 bEqZero);
946 Value cutoff2 =
947 arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
948
949 // Case 3:
950 // inf^(c + 0*i) = inf + 0*i, c > 0
951 Value lhsEqInf = arith::AndIOp::create(
952 builder,
953 arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
954 bEqZero);
955 Value rhsGt0 = arith::AndIOp::create(
956 builder, dEqZero,
957 arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
958 Value cutoff3 = arith::SelectOp::create(
959 builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
960 cutoff2);
961
962 // Case 4:
963 // inf^(c + 0*i) = 0 + 0*i, c < 0
964 Value rhsLt0 = arith::AndIOp::create(
965 builder, dEqZero,
966 arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
967 Value cutoff4 = arith::SelectOp::create(
968 builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
969 cutoff3);
970
971 return cutoff4;
972}
973
974struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
975 using OpConversionPattern<complex::PowiOp>::OpConversionPattern;
976
977 LogicalResult
978 matchAndRewrite(complex::PowiOp op, OpAdaptor adaptor,
979 ConversionPatternRewriter &rewriter) const override {
980 ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
981 auto type = cast<ComplexType>(op.getType());
982 auto elementType = cast<FloatType>(type.getElementType());
983
984 Value floatExponent =
985 arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
986 Value zero = arith::ConstantOp::create(
987 builder, elementType, builder.getFloatAttr(elementType, 0.0));
988 Value complexExponent =
989 complex::CreateOp::create(builder, type, floatExponent, zero);
990
991 auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
992 complexExponent, op.getFastmathAttr());
993 rewriter.replaceOp(op, pow.getResult());
994 return success();
995 }
996};
997
998struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
999 using OpConversionPattern<complex::PowOp>::OpConversionPattern;
1000
1001 LogicalResult
1002 matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
1003 ConversionPatternRewriter &rewriter) const override {
1004 mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
1005 auto type = cast<ComplexType>(adaptor.getLhs().getType());
1006 auto elementType = cast<FloatType>(type.getElementType());
1007
1008 Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
1009 Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
1010
1011 rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
1012 c, d, op.getFastmath())});
1013 return success();
1014 }
1015};
1016
1017struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
1018 using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
1019
1020 LogicalResult
1021 matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
1022 ConversionPatternRewriter &rewriter) const override {
1023 mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1024 auto type = cast<ComplexType>(adaptor.getComplex().getType());
1025 auto elementType = cast<FloatType>(type.getElementType());
1026
1027 arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
1028
1029 auto cst = [&](APFloat v) {
1030 return arith::ConstantOp::create(b, elementType,
1031 b.getFloatAttr(elementType, v));
1032 };
1033 const auto &floatSemantics = elementType.getFloatSemantics();
1034 Value zero = cst(APFloat::getZero(floatSemantics));
1035 Value inf = cst(APFloat::getInf(floatSemantics));
1036 Value negHalf = arith::ConstantOp::create(
1037 b, elementType, b.getFloatAttr(elementType, -0.5));
1038 Value nan = cst(APFloat::getNaN(floatSemantics));
1039
1040 Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
1041 Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
1042 Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
1043 Value argArg = math::Atan2Op::create(b, imag, real, fmf);
1044 Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
1045 Value cos = math::CosOp::create(b, rsqrtArg, fmf);
1046 Value sin = math::SinOp::create(b, rsqrtArg, fmf);
1047
1048 Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
1049 Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
1050
1051 if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
1052 arith::FastMathFlags::ninf)) {
1053 Value negOne = arith::ConstantOp::create(b, elementType,
1054 b.getFloatAttr(elementType, -1));
1055
1056 Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
1057 Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
1058 Value negImagSignedZero =
1059 arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
1060
1061 Value absReal = math::AbsFOp::create(b, real, fmf);
1062 Value absImag = math::AbsFOp::create(b, imag, fmf);
1063
1064 Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1065 absImag, inf, fmf);
1066 Value realIsNan =
1067 arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
1068 Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
1069 absReal, inf, fmf);
1070 Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
1071
1072 Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
1073
1074 resultReal =
1075 arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
1076 resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
1077 resultImag);
1078 }
1079
1080 Value isRealZero =
1081 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
1082 Value isImagZero =
1083 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
1084 Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
1085
1086 resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
1087 resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
1088
1089 rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
1090 resultImag);
1091 return success();
1092 }
1093};
1094
1095struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
1096 using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
1097
1098 LogicalResult
1099 matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
1100 ConversionPatternRewriter &rewriter) const override {
1101 auto loc = op.getLoc();
1102 auto type = op.getType();
1103 arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
1104
1105 Value real =
1106 complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
1107 Value imag =
1108 complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
1109
1110 rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
1111
1112 return success();
1113 }
1114};
1115
1116} // namespace
1117
1119 RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
1120 // clang-format off
1121 patterns.add<
1122 AbsOpConversion,
1123 AngleOpConversion,
1124 Atan2OpConversion,
1125 BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
1126 BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
1127 ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
1128 ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
1129 ConjOpConversion,
1130 CosOpConversion,
1131 ExpOpConversion,
1132 Expm1OpConversion,
1133 Log1pOpConversion,
1134 LogOpConversion,
1135 MulOpConversion,
1136 NegOpConversion,
1137 SignOpConversion,
1138 SinOpConversion,
1139 SqrtOpConversion,
1140 TanTanhOpConversion<complex::TanOp>,
1141 TanTanhOpConversion<complex::TanhOp>,
1142 PowiOpConversion,
1143 PowOpConversion,
1144 RsqrtOpConversion
1145 >(patterns.getContext());
1146
1147 patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
1148
1149 // clang-format on
1150}
1151
1152namespace {
1153struct ConvertComplexToStandardPass
1154 : public impl::ConvertComplexToStandardPassBase<
1155 ConvertComplexToStandardPass> {
1156 using Base::Base;
1157
1158 void runOnOperation() override;
1159};
1160
1161void ConvertComplexToStandardPass::runOnOperation() {
1162 // Convert to the Standard dialect using the converter defined above.
1163 RewritePatternSet patterns(&getContext());
1164 populateComplexToStandardConversionPatterns(patterns, complexRange);
1165
1166 ConversionTarget target(getContext());
1167 target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
1168 target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1169 if (failed(
1170 applyPartialConversion(getOperation(), target, std::move(patterns))))
1171 signalPassFailure();
1172}
1173} // namespace
return success()
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:268
NestedPattern Op(FilterFunctionType filter=defaultFilterFunction)
void convertDivToStandardUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using algebraic method
void convertDivToStandardUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, arith::FastMathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the arith/math dialects using Smith's method
Fraction abs(const Fraction &f)
Definition Fraction.h:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
OwningOpRef< spirv::ModuleOp > combine(ArrayRef< spirv::ModuleOp > inputModules, OpBuilder &combinedModuleBuilder, SymbolRenameListener symRenameListener)
Combines a list of SPIR-V inputModules into one.
Include the generated interface declarations.
constexpr T real(const NonFloatComplex< T > &x)
Definition Complex.h:255
void populateComplexToStandardConversionPatterns(RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::improved)
Populate the given list with patterns that convert from Complex to Standard.
constexpr T imag(const NonFloatComplex< T > &x)
Definition Complex.h:260