MLIR 23.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using namespace mlir::arith;
28
29//===----------------------------------------------------------------------===//
30// ComplexStructBuilder implementation.
31//===----------------------------------------------------------------------===//
32
// Field indices of the two components inside the LLVM struct that represents
// a lowered complex number: real part first, imaginary part second.
static constexpr unsigned kRealPosInComplexNumberStruct{0};
static constexpr unsigned kImaginaryPosInComplexNumberStruct{1};
35
                                                  Location loc, Type type) {
  // Materialize a poison value of the given (struct) type; callers then fill
  // in the real and imaginary fields through setReal/setImaginary.
  Value val = LLVM::PoisonOp::create(builder, loc, type);
  return ComplexStructBuilder(val);
}
41
46
50
55
59
60//===----------------------------------------------------------------------===//
61// Conversion patterns.
62//===----------------------------------------------------------------------===//
63
64namespace {
65
/// Lowers `complex.abs` to sqrt(re * re + im * im), emitted as LLVM fmul,
/// fadd and sqrt ops.
struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {

  LogicalResult
  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();

    // Split the lowered operand into its real and imaginary components.
    ComplexStructBuilder complexStruct(adaptor.getComplex());
    Value real = complexStruct.real(rewriter, op.getLoc());
    Value imag = complexStruct.imaginary(rewriter, op.getLoc());

    // Translate the op's arith fastmath flags into LLVM fastmath flags.
    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        op.getContext(),
        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // Squared norm: re * re + im * im.
    // NOTE(review): the squares are formed directly, so very large or very
    // small components can overflow/underflow where a scaled (hypot-style)
    // formulation would not.
    Value sqNorm = LLVM::FAddOp::create(
        rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
        LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);

    // NOTE(review): `fmf` is not passed to the sqrt replacement, so the
    // fastmath flags stop at the fadd; confirm whether this is intended.
    rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
    return success();
  }
};
90
/// Lowers `complex.constant` to an LLVM::ConstantOp, carrying over the
/// original op's operands and attributes and passing no properties attribute.
struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {

  LogicalResult
  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
        op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
        op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
        rewriter);
  }
};
103
104struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
105 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
106
107 LogicalResult
108 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter) const override {
110 // Pack real and imaginary part in a complex number struct.
111 auto loc = complexOp.getLoc();
112 auto structType = typeConverter->convertType(complexOp.getType());
113 auto complexStruct =
114 ComplexStructBuilder::poison(rewriter, loc, structType);
115 complexStruct.setReal(rewriter, loc, adaptor.getReal());
116 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117
118 rewriter.replaceOp(complexOp, {complexStruct});
119 return success();
120 }
121};
122
123struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
124 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
125
126 LogicalResult
127 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 // Extract real part from the complex number struct.
130 ComplexStructBuilder complexStruct(adaptor.getComplex());
131 Value real = complexStruct.real(rewriter, op.getLoc());
132 rewriter.replaceOp(op, real);
133
134 return success();
135 }
136};
137
138struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
139 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
140
141 LogicalResult
142 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 // Extract imaginary part from the complex number struct.
145 ComplexStructBuilder complexStruct(adaptor.getComplex());
146 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
147 rewriter.replaceOp(op, imaginary);
148
149 return success();
150 }
151};
152
/// Real/imaginary components of the lhs and rhs operands of a binary complex
/// op, as filled in by unpackBinaryComplexOperands.
struct BinaryComplexOperands {
};
157
158template <typename OpTy>
159BinaryComplexOperands
160unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
161 ConversionPatternRewriter &rewriter) {
162 auto loc = op.getLoc();
163
164 // Extract real and imaginary values from operands.
165 BinaryComplexOperands unpacked;
166 ComplexStructBuilder lhs(adaptor.getLhs());
167 unpacked.lhs.real(lhs.real(rewriter, loc));
168 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
169 ComplexStructBuilder rhs(adaptor.getRhs());
170 unpacked.rhs.real(rhs.real(rewriter, loc));
171 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
172
173 return unpacked;
174}
175
176struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
177 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
178
179 LogicalResult
180 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 auto loc = op.getLoc();
183 BinaryComplexOperands arg =
184 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
185
186 // Initialize complex number struct for result.
187 auto structType = typeConverter->convertType(op.getType());
188 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
189
190 // Emit IR to add complex numbers.
191 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
192 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
193 op.getContext(),
194 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
195 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
196 arg.rhs.real(), fmf);
197 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
198 arg.rhs.imag(), fmf);
199 result.setReal(rewriter, loc, real);
200 result.setImaginary(rewriter, loc, imag);
201
202 rewriter.replaceOp(op, {result});
203 return success();
204 }
205};
206
/// Lowers `complex.div`; the lowering strategy is selected by the
/// `complexRange` flag supplied at construction time.
struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
  DivOpConversion(const LLVMTypeConverter &converter,
                  complex::ComplexRangeFlags target)
      : ConvertOpToLLVMPattern<complex::DivOp>(converter),
        complexRange(target) {}

  // NOTE(review): this inherited constructor does not initialize
  // `complexRange` (the member below has no default initializer), so a
  // pattern built through it would read an indeterminate value in
  // matchAndRewrite. Consider a default member initializer (e.g.
  // complex::ComplexRangeFlags::basic) or removing this using-declaration.
  using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    BinaryComplexOperands arg =
        unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);

    // Initialize complex number struct for result.
    auto structType = typeConverter->convertType(op.getType());
    auto result = ComplexStructBuilder::poison(rewriter, loc, structType);

    // Emit IR to divide complex numbers, forwarding the op's fastmath flags.
    arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
    LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
        op.getContext(),
        convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    Value rhsRe = arg.rhs.real();
    Value rhsIm = arg.rhs.imag();
    Value lhsRe = arg.lhs.real();
    Value lhsIm = arg.lhs.imag();

    // Out-parameters filled in by the division helper the branch selects.
    Value resultRe, resultIm;

    // `basic` and `none` share one lowering; `improved` selects another.
    // NOTE(review): for any other flag value neither branch runs, and the
    // null resultRe/resultIm are then passed to setReal/setImaginary below.
    if (complexRange == complex::ComplexRangeFlags::basic ||
        complexRange == complex::ComplexRangeFlags::none) {
          rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
    } else if (complexRange == complex::ComplexRangeFlags::improved) {
          rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
    }

    result.setReal(rewriter, loc, resultRe);
    result.setImaginary(rewriter, loc, resultIm);

    rewriter.replaceOp(op, {result});
    return success();
  }

private:
  // Division strategy requested by the pass / pattern populator.
  complex::ComplexRangeFlags complexRange;
};
257
258struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
259 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
260
261 LogicalResult
262 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264 auto loc = op.getLoc();
265 BinaryComplexOperands arg =
266 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
267
268 // Initialize complex number struct for result.
269 auto structType = typeConverter->convertType(op.getType());
270 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
271
272 // Emit IR to add complex numbers.
273 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
274 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
275 op.getContext(),
276 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
277 Value rhsRe = arg.rhs.real();
278 Value rhsIm = arg.rhs.imag();
279 Value lhsRe = arg.lhs.real();
280 Value lhsIm = arg.lhs.imag();
281
282 Value real;
283 Value imag;
284 if (arith::bitEnumContainsAll(complexFMFAttr.getValue(),
285 arith::FastMathFlags::contract)) {
286 Value lhsImagTimesRhsImag =
287 LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsIm, fmf);
288 Value negLhsImagTimesRhsImag =
289 LLVM::FNegOp::create(rewriter, loc, lhsImagTimesRhsImag, fmf);
290 real = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsRe,
291 negLhsImagTimesRhsImag, fmf);
292
293 Value lhsImagTimesRhsReal =
294 LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
295 imag = LLVM::FMAOp::create(rewriter, loc, lhsRe, rhsIm,
296 lhsImagTimesRhsReal, fmf);
297 } else {
298 Value lhsRealTimesRhsReal =
299 LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf);
300 Value lhsImagTimesRhsImag =
301 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf);
302 Value lhsImagTimesRhsReal =
303 LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf);
304 Value lhsRealTimesRhsImag =
305 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf);
306
307 real = LLVM::FSubOp::create(rewriter, loc, lhsRealTimesRhsReal,
308 lhsImagTimesRhsImag, fmf);
309
310 imag = LLVM::FAddOp::create(rewriter, loc, lhsImagTimesRhsReal,
311 lhsRealTimesRhsImag, fmf);
312 }
313
314 result.setReal(rewriter, loc, real);
315 result.setImaginary(rewriter, loc, imag);
316
317 rewriter.replaceOp(op, {result});
318 return success();
319 }
320};
321
322struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
323 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
324
325 LogicalResult
326 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
327 ConversionPatternRewriter &rewriter) const override {
328 auto loc = op.getLoc();
329 BinaryComplexOperands arg =
330 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
331
332 // Initialize complex number struct for result.
333 auto structType = typeConverter->convertType(op.getType());
334 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
335
336 // Emit IR to substract complex numbers.
337 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
338 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
339 op.getContext(),
340 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
341 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
342 arg.rhs.real(), fmf);
343 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
344 arg.rhs.imag(), fmf);
345 result.setReal(rewriter, loc, real);
346 result.setImaginary(rewriter, loc, imag);
347
348 rewriter.replaceOp(op, {result});
349 return success();
350 }
351};
352} // namespace
353
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    complex::ComplexRangeFlags complexRange) {
  // clang-format off
  // Patterns that only need the type converter.
  patterns.add<
      AbsOpConversion,
      AddOpConversion,
      ConstantOpLowering,
      CreateOpConversion,
      ImOpConversion,
      MulOpConversion,
      ReOpConversion,
      SubOpConversion
      >(converter);

  // DivOpConversion is the only pattern that depends on the requested
  // complex-range mode, so it is constructed with it separately.
  patterns.add<DivOpConversion>(converter, complexRange);
  // clang-format on
}
372
373namespace {
374struct ConvertComplexToLLVMPass
375 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
376 using Base::Base;
377
378 void runOnOperation() override;
379};
380} // namespace
381
/// Lowers complex dialect ops nested in the current operation to the LLVM
/// dialect; the `complexRange` setting is forwarded to the pattern populator,
/// which hands it to the division lowering.
void ConvertComplexToLLVMPass::runOnOperation() {
  // Convert to the LLVM IR dialect using the converter defined above.
  RewritePatternSet patterns(&getContext());
  LLVMTypeConverter converter(&getContext());
  populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);

  // Every complex dialect op must be converted; if partial conversion
  // fails, the pass signals failure.
  target.addIllegalDialect<complex::ComplexDialect>();
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
394
395//===----------------------------------------------------------------------===//
396// ConvertToLLVMPatternInterface implementation
397//===----------------------------------------------------------------------===//
398
399namespace {
400/// Implement the interface to convert MemRef to LLVM.
401struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
402 ComplexToLLVMDialectInterface(Dialect *dialect)
403 : ConvertToLLVMPatternInterface(dialect) {}
404
405 void loadDependentDialects(MLIRContext *context) const final {
406 context->loadDialect<LLVM::LLVMDialect>();
407 }
408
409 /// Hook for derived dialect interface to provide conversion patterns
410 /// and mark dialect legal for the conversion target.
411 void populateConvertToLLVMConversionPatterns(
412 ConversionTarget &target, LLVMTypeConverter &typeConverter,
413 RewritePatternSet &patterns) const final {
414 populateComplexToLLVMConversionPatterns(typeConverter, patterns);
415 }
416};
417} // namespace
418
  // Register a dialect extension that attaches the ConvertToLLVM pattern
  // interface to the complex dialect; the context parameter is unused.
  registry.addExtension(
      +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
        dialect->addInterfaces<ComplexToLLVMDialectInterface>();
      });
}
return success()
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
lhs
b getContext())
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:227
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:233
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:313
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the LLVM dialect using Smith's method
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the LLVM dialect using algebraic method
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
constexpr T real(const NonFloatComplex< T > &x)
Definition Complex.h:255
std::conditional_t< std::is_floating_point_v< T >, std::complex< T >, NonFloatComplex< T > > Complex
Definition Complex.h:265
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
constexpr T imag(const NonFloatComplex< T > &x)
Definition Complex.h:260