MLIR 22.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1//===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19
20namespace mlir {
21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::LLVM;
27using namespace mlir::arith;
28
29//===----------------------------------------------------------------------===//
30// ComplexStructBuilder implementation.
31//===----------------------------------------------------------------------===//
32
/// Struct field index that holds the real component of a lowered complex
/// number. (`constexpr` at namespace scope already implies internal linkage.)
constexpr unsigned kRealPosInComplexNumberStruct = 0;
/// Struct field index that holds the imaginary component of a lowered complex
/// number.
constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
35
37 Location loc, Type type) {
38 Value val = LLVM::PoisonOp::create(builder, loc, type);
39 return ComplexStructBuilder(val);
40}
41
46
50
55
59
60//===----------------------------------------------------------------------===//
61// Conversion patterns.
62//===----------------------------------------------------------------------===//
63
64namespace {
65
66struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
68
69 LogicalResult
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter) const override {
72 auto loc = op.getLoc();
73
74 ComplexStructBuilder complexStruct(adaptor.getComplex());
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
77
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80 op.getContext(),
81 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
82 Value sqNorm = LLVM::FAddOp::create(
83 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84 LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
85
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87 return success();
88 }
89};
90
91struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
93
94 LogicalResult
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter) const override {
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), /*propAttr=*/Attribute{}, *getTypeConverter(),
100 rewriter);
101 }
102};
103
104struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
105 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
106
107 LogicalResult
108 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter) const override {
110 // Pack real and imaginary part in a complex number struct.
111 auto loc = complexOp.getLoc();
112 auto structType = typeConverter->convertType(complexOp.getType());
113 auto complexStruct =
114 ComplexStructBuilder::poison(rewriter, loc, structType);
115 complexStruct.setReal(rewriter, loc, adaptor.getReal());
116 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117
118 rewriter.replaceOp(complexOp, {complexStruct});
119 return success();
120 }
121};
122
123struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
124 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
125
126 LogicalResult
127 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 // Extract real part from the complex number struct.
130 ComplexStructBuilder complexStruct(adaptor.getComplex());
131 Value real = complexStruct.real(rewriter, op.getLoc());
132 rewriter.replaceOp(op, real);
133
134 return success();
135 }
136};
137
138struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
139 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
140
141 LogicalResult
142 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 // Extract imaginary part from the complex number struct.
145 ComplexStructBuilder complexStruct(adaptor.getComplex());
146 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
147 rewriter.replaceOp(op, imaginary);
148
149 return success();
150 }
151};
152
153struct BinaryComplexOperands {
154 std::complex<Value> lhs;
155 std::complex<Value> rhs;
156};
157
158template <typename OpTy>
159BinaryComplexOperands
160unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
161 ConversionPatternRewriter &rewriter) {
162 auto loc = op.getLoc();
163
164 // Extract real and imaginary values from operands.
165 BinaryComplexOperands unpacked;
166 ComplexStructBuilder lhs(adaptor.getLhs());
167 unpacked.lhs.real(lhs.real(rewriter, loc));
168 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
169 ComplexStructBuilder rhs(adaptor.getRhs());
170 unpacked.rhs.real(rhs.real(rewriter, loc));
171 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
172
173 return unpacked;
174}
175
176struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
177 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
178
179 LogicalResult
180 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 auto loc = op.getLoc();
183 BinaryComplexOperands arg =
184 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
185
186 // Initialize complex number struct for result.
187 auto structType = typeConverter->convertType(op.getType());
188 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
189
190 // Emit IR to add complex numbers.
191 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
192 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
193 op.getContext(),
194 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
195 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
196 arg.rhs.real(), fmf);
197 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
198 arg.rhs.imag(), fmf);
199 result.setReal(rewriter, loc, real);
200 result.setImaginary(rewriter, loc, imag);
201
202 rewriter.replaceOp(op, {result});
203 return success();
204 }
205};
206
207struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
208 DivOpConversion(const LLVMTypeConverter &converter,
209 complex::ComplexRangeFlags target)
210 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
211 complexRange(target) {}
212
213 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
214
215 LogicalResult
216 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
217 ConversionPatternRewriter &rewriter) const override {
218 auto loc = op.getLoc();
219 BinaryComplexOperands arg =
220 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
221
222 // Initialize complex number struct for result.
223 auto structType = typeConverter->convertType(op.getType());
224 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
225
226 // Emit IR to add complex numbers.
227 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
228 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
229 op.getContext(),
230 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
231 Value rhsRe = arg.rhs.real();
232 Value rhsIm = arg.rhs.imag();
233 Value lhsRe = arg.lhs.real();
234 Value lhsIm = arg.lhs.imag();
235
236 Value resultRe, resultIm;
237
238 if (complexRange == complex::ComplexRangeFlags::basic ||
239 complexRange == complex::ComplexRangeFlags::none) {
241 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
242 } else if (complexRange == complex::ComplexRangeFlags::improved) {
244 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
245 }
246
247 result.setReal(rewriter, loc, resultRe);
248 result.setImaginary(rewriter, loc, resultIm);
249
250 rewriter.replaceOp(op, {result});
251 return success();
252 }
253
254private:
255 complex::ComplexRangeFlags complexRange;
256};
257
258struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
259 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
260
261 LogicalResult
262 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter) const override {
264 auto loc = op.getLoc();
265 BinaryComplexOperands arg =
266 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
267
268 // Initialize complex number struct for result.
269 auto structType = typeConverter->convertType(op.getType());
270 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
271
272 // Emit IR to add complex numbers.
273 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
274 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
275 op.getContext(),
276 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
277 Value rhsRe = arg.rhs.real();
278 Value rhsIm = arg.rhs.imag();
279 Value lhsRe = arg.lhs.real();
280 Value lhsIm = arg.lhs.imag();
281
282 Value real = LLVM::FSubOp::create(
283 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
284 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
285
286 Value imag = LLVM::FAddOp::create(
287 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
288 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
289
290 result.setReal(rewriter, loc, real);
291 result.setImaginary(rewriter, loc, imag);
292
293 rewriter.replaceOp(op, {result});
294 return success();
295 }
296};
297
298struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
299 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
300
301 LogicalResult
302 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 auto loc = op.getLoc();
305 BinaryComplexOperands arg =
306 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
307
308 // Initialize complex number struct for result.
309 auto structType = typeConverter->convertType(op.getType());
310 auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
311
312 // Emit IR to substract complex numbers.
313 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
314 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
315 op.getContext(),
316 convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
317 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
318 arg.rhs.real(), fmf);
319 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
320 arg.rhs.imag(), fmf);
321 result.setReal(rewriter, loc, real);
322 result.setImaginary(rewriter, loc, imag);
323
324 rewriter.replaceOp(op, {result});
325 return success();
326 }
327};
328} // namespace
329
332 complex::ComplexRangeFlags complexRange) {
333 // clang-format off
334 patterns.add<
335 AbsOpConversion,
336 AddOpConversion,
337 ConstantOpLowering,
338 CreateOpConversion,
339 ImOpConversion,
340 MulOpConversion,
341 ReOpConversion,
342 SubOpConversion
343 >(converter);
344
345 patterns.add<DivOpConversion>(converter, complexRange);
346 // clang-format on
347}
348
349namespace {
350struct ConvertComplexToLLVMPass
351 : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
352 using Base::Base;
353
354 void runOnOperation() override;
355};
356} // namespace
357
358void ConvertComplexToLLVMPass::runOnOperation() {
359 // Convert to the LLVM IR dialect using the converter defined above.
361 LLVMTypeConverter converter(&getContext());
362 populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
363
365 target.addIllegalDialect<complex::ComplexDialect>();
366 if (failed(
367 applyPartialConversion(getOperation(), target, std::move(patterns))))
368 signalPassFailure();
369}
370
371//===----------------------------------------------------------------------===//
372// ConvertToLLVMPatternInterface implementation
373//===----------------------------------------------------------------------===//
374
375namespace {
/// Implement the interface to convert Complex to LLVM.
377struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
379 void loadDependentDialects(MLIRContext *context) const final {
380 context->loadDialect<LLVM::LLVMDialect>();
381 }
382
383 /// Hook for derived dialect interface to provide conversion patterns
384 /// and mark dialect legal for the conversion target.
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget &target, LLVMTypeConverter &typeConverter,
387 RewritePatternSet &patterns) const final {
389 }
390};
391} // namespace
392
394 registry.addExtension(
395 +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
396 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
397 });
398}
return success()
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
lhs
b getContext())
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition Pattern.h:207
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition Pattern.cpp:301
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's (range-reduction) method.
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)