MLIR 22.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using namespace mlir::arith;
28
29//===----------------------------------------------------------------------===//
30// ComplexStructBuilder implementation.
31//===----------------------------------------------------------------------===//
32
33static constexpr unsigned kRealPosInComplexNumberStruct = 0;
34static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
35
37 Location loc, Type type) {
38 Value val = LLVM::PoisonOp::create(builder, loc, type);
39 return ComplexStructBuilder(val);
40}
41
46
50
55
59
60//===----------------------------------------------------------------------===//
61// Conversion patterns.
62//===----------------------------------------------------------------------===//
63
64namespace {
65
66struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
68
69 LogicalResult
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override {
72 auto loc = op.getLoc();
73
74 ComplexStructBuilder complexStruct(adaptor.getComplex());
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
77
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80 op.getContext(),
81 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
82 Value sqNorm = LLVM::FAddOp::create(
83 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84 LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
85
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87 return success();
88 }
89};
90
91struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
93
94 LogicalResult
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter) const override {
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), *getTypeConverter(), rewriter);
100 }
101};
102
103struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
104 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
105
106 LogicalResult
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter) const override {
109 // Pack real and imaginary part in a complex number struct.
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(complexOp.getType());
112 auto complexStruct =
113 ComplexStructBuilder::poison(rewriter, loc, structType);
114 complexStruct.setReal(rewriter, loc, adaptor.getReal());
115 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
116
117 rewriter.replaceOp(complexOp, {complexStruct});
118 return success();
119 }
120};
121
122struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
123 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
124
125 LogicalResult
126 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
128 // Extract real part from the complex number struct.
129 ComplexStructBuilder complexStruct(adaptor.getComplex());
130 Value real = complexStruct.real(rewriter, op.getLoc());
131 rewriter.replaceOp(op, real);
132
133 return success();
134 }
135};
136
137struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
138 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
139
140 LogicalResult
141 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter) const override {
143 // Extract imaginary part from the complex number struct.
144 ComplexStructBuilder complexStruct(adaptor.getComplex());
145 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
146 rewriter.replaceOp(op, imaginary);
147
148 return success();
149 }
150};
151
152struct BinaryComplexOperands {
153 std::complex<Value> lhs;
154 std::complex<Value> rhs;
155};
156
157template <typename OpTy>
158BinaryComplexOperands
159unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
160 ConversionPatternRewriter &rewriter) {
161 auto loc = op.getLoc();
162
163 // Extract real and imaginary values from operands.
164 BinaryComplexOperands unpacked;
165 ComplexStructBuilder lhs(adaptor.getLhs());
166 unpacked.lhs.real(lhs.real(rewriter, loc));
167 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
168 ComplexStructBuilder rhs(adaptor.getRhs());
169 unpacked.rhs.real(rhs.real(rewriter, loc));
170 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
171
172 return unpacked;
173}
174
175struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
176 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
177
178 LogicalResult
179 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter) const override {
181 auto loc = op.getLoc();
182 BinaryComplexOperands arg =
183 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
184
185 // Initialize complex number struct for result.
186 auto structType = typeConverter->convertType(op.getType());
187 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
188
189 // Emit IR to add complex numbers.
190 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
191 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
192 op.getContext(),
193 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
194 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
195 arg.rhs.real(), fmf);
196 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
197 arg.rhs.imag(), fmf);
198 result.setReal(rewriter, loc, real);
199 result.setImaginary(rewriter, loc, imag);
200
201 rewriter.replaceOp(op, {result});
202 return success();
203 }
204};
205
206struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
207 DivOpConversion(const LLVMTypeConverter &converter,
208 complex::ComplexRangeFlags target)
209 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
210 complexRange(target) {}
211
212 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
213
214 LogicalResult
215 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter) const override {
217 auto loc = op.getLoc();
218 BinaryComplexOperands arg =
219 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
220
221 // Initialize complex number struct for result.
222 auto structType = typeConverter->convertType(op.getType());
223 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
224
225 // Emit IR to add complex numbers.
226 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
227 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
228 op.getContext(),
229 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
230 Value rhsRe = arg.rhs.real();
231 Value rhsIm = arg.rhs.imag();
232 Value lhsRe = arg.lhs.real();
233 Value lhsIm = arg.lhs.imag();
234
235 Value resultRe, resultIm;
236
237 if (complexRange == complex::ComplexRangeFlags::basic ||
238 complexRange == complex::ComplexRangeFlags::none) {
240 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
241 } else if (complexRange == complex::ComplexRangeFlags::improved) {
243 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
244 }
245
246 result.setReal(rewriter, loc, resultRe);
247 result.setImaginary(rewriter, loc, resultIm);
248
249 rewriter.replaceOp(op, {result});
250 return success();
251 }
252
253private:
254 complex::ComplexRangeFlags complexRange;
255};
256
257struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
258 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
259
260 LogicalResult
261 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const override {
263 auto loc = op.getLoc();
264 BinaryComplexOperands arg =
265 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
266
267 // Initialize complex number struct for result.
268 auto structType = typeConverter->convertType(op.getType());
269 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
270
271 // Emit IR to add complex numbers.
272 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
273 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
274 op.getContext(),
275 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
276 Value rhsRe = arg.rhs.real();
277 Value rhsIm = arg.rhs.imag();
278 Value lhsRe = arg.lhs.real();
279 Value lhsIm = arg.lhs.imag();
280
281 Value real = LLVM::FSubOp::create(
282 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
283 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
284
285 Value imag = LLVM::FAddOp::create(
286 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
287 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
288
289 result.setReal(rewriter, loc, real);
290 result.setImaginary(rewriter, loc, imag);
291
292 rewriter.replaceOp(op, {result});
293 return success();
294 }
295};
296
297struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
298 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
299
300 LogicalResult
301 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
302 ConversionPatternRewriter &rewriter) const override {
303 auto loc = op.getLoc();
304 BinaryComplexOperands arg =
305 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
306
307 // Initialize complex number struct for result.
308 auto structType = typeConverter->convertType(op.getType());
309 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
310
311 // Emit IR to substract complex numbers.
312 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
313 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
314 op.getContext(),
315 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
316 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
317 arg.rhs.real(), fmf);
318 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
319 arg.rhs.imag(), fmf);
320 result.setReal(rewriter, loc, real);
321 result.setImaginary(rewriter, loc, imag);
322
323 rewriter.replaceOp(op, {result});
324 return success();
325 }
326};
327} // namespace
328
331 complex::ComplexRangeFlags complexRange) {
332 // clang-format off
333 patterns.add<
334 AbsOpConversion,
335 AddOpConversion,
336 ConstantOpLowering,
337 CreateOpConversion,
338 ImOpConversion,
339 MulOpConversion,
340 ReOpConversion,
341 SubOpConversion
342 >(converter);
343
344 patterns.add<DivOpConversion>(converter, complexRange);
345 // clang-format on
346}
347
348namespace {
349struct ConvertComplexToLLVMPass
350 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
351 using Base::Base;
352
353 void runOnOperation() override;
354};
355} // namespace
356
357void ConvertComplexToLLVMPass::runOnOperation() {
358 // Convert to the LLVM IR dialect using the converter defined above.
360 LLVMTypeConverter converter(&getContext());
361 populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
362
364 target.addIllegalDialect<complex::ComplexDialect>();
365 if (failed(
366 applyPartialConversion(getOperation(), target, std::move(patterns))))
367 signalPassFailure();
368}
369
370//===----------------------------------------------------------------------===//
371// ConvertToLLVMPatternInterface implementation
372//===----------------------------------------------------------------------===//
373
374namespace {
375/// Implement the interface to convert MemRef to LLVM.
376struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
378 void loadDependentDialects(MLIRContext *context) const final {
379 context->loadDialect<LLVM::LLVMDialect>();
380 }
381
382 /// Hook for derived dialect interface to provide conversion patterns
383 /// and mark dialect legal for the conversion target.
384 void populateConvertToLLVMConversionPatterns(
385 ConversionTarget &target, LLVMTypeConverter &typeConverter,
386 RewritePatternSet &patterns) const final {
388 }
389};
390} // namespace
391
393 registry.addExtension(
394 +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
395 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
396 });
397}
return success()
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
lhs
b getContext())
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:307
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's method.
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)