MLIR  17.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/Pass/Pass.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVM
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::LLVM;
25 
26 //===----------------------------------------------------------------------===//
27 // ComplexStructBuilder implementation.
28 //===----------------------------------------------------------------------===//
29 
// Field indices of the two members of the LLVM struct that a lowered complex
// number is packed into: the real part comes first, the imaginary part second.
static constexpr unsigned kRealPosInComplexNumberStruct{0U};
static constexpr unsigned kImaginaryPosInComplexNumberStruct{1U};
32 
34  Location loc, Type type) {
35  Value val = builder.create<LLVM::UndefOp>(loc, type);
36  return ComplexStructBuilder(val);
37 }
38 
40  Value real) {
41  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
42 }
43 
45  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
46 }
47 
49  Value imaginary) {
50  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
51 }
52 
54  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // Conversion patterns.
59 //===----------------------------------------------------------------------===//
60 
61 namespace {
62 
63 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
65 
67  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override {
69  auto loc = op.getLoc();
70 
71  ComplexStructBuilder complexStruct(adaptor.getComplex());
72  Value real = complexStruct.real(rewriter, op.getLoc());
73  Value imag = complexStruct.imaginary(rewriter, op.getLoc());
74 
75  auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
76  Value sqNorm = rewriter.create<LLVM::FAddOp>(
77  loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
78  rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
79 
80  rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
81  return success();
82  }
83 };
84 
85 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
87 
89  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
90  ConversionPatternRewriter &rewriter) const override {
92  op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
93  op->getAttrs(), *getTypeConverter(), rewriter);
94  }
95 };
96 
97 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
99 
101  matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
102  ConversionPatternRewriter &rewriter) const override {
103  // Pack real and imaginary part in a complex number struct.
104  auto loc = complexOp.getLoc();
105  auto structType = typeConverter->convertType(complexOp.getType());
106  auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
107  complexStruct.setReal(rewriter, loc, adaptor.getReal());
108  complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
109 
110  rewriter.replaceOp(complexOp, {complexStruct});
111  return success();
112  }
113 };
114 
115 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
117 
119  matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override {
121  // Extract real part from the complex number struct.
122  ComplexStructBuilder complexStruct(adaptor.getComplex());
123  Value real = complexStruct.real(rewriter, op.getLoc());
124  rewriter.replaceOp(op, real);
125 
126  return success();
127  }
128 };
129 
130 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
132 
134  matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
135  ConversionPatternRewriter &rewriter) const override {
136  // Extract imaginary part from the complex number struct.
137  ComplexStructBuilder complexStruct(adaptor.getComplex());
138  Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
139  rewriter.replaceOp(op, imaginary);
140 
141  return success();
142  }
143 };
144 
145 struct BinaryComplexOperands {
146  std::complex<Value> lhs;
147  std::complex<Value> rhs;
148 };
149 
150 template <typename OpTy>
151 BinaryComplexOperands
152 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
153  ConversionPatternRewriter &rewriter) {
154  auto loc = op.getLoc();
155 
156  // Extract real and imaginary values from operands.
157  BinaryComplexOperands unpacked;
158  ComplexStructBuilder lhs(adaptor.getLhs());
159  unpacked.lhs.real(lhs.real(rewriter, loc));
160  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
161  ComplexStructBuilder rhs(adaptor.getRhs());
162  unpacked.rhs.real(rhs.real(rewriter, loc));
163  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
164 
165  return unpacked;
166 }
167 
168 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
170 
172  matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
173  ConversionPatternRewriter &rewriter) const override {
174  auto loc = op.getLoc();
175  BinaryComplexOperands arg =
176  unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
177 
178  // Initialize complex number struct for result.
179  auto structType = typeConverter->convertType(op.getType());
180  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
181 
182  // Emit IR to add complex numbers.
183  auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
184  Value real =
185  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
186  Value imag =
187  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
188  result.setReal(rewriter, loc, real);
189  result.setImaginary(rewriter, loc, imag);
190 
191  rewriter.replaceOp(op, {result});
192  return success();
193  }
194 };
195 
196 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
198 
200  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
201  ConversionPatternRewriter &rewriter) const override {
202  auto loc = op.getLoc();
203  BinaryComplexOperands arg =
204  unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
205 
206  // Initialize complex number struct for result.
207  auto structType = typeConverter->convertType(op.getType());
208  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
209 
210  // Emit IR to add complex numbers.
211  auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
212  Value rhsRe = arg.rhs.real();
213  Value rhsIm = arg.rhs.imag();
214  Value lhsRe = arg.lhs.real();
215  Value lhsIm = arg.lhs.imag();
216 
217  Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
218  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
219  rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
220 
221  Value resultReal = rewriter.create<LLVM::FAddOp>(
222  loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
223  rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
224 
225  Value resultImag = rewriter.create<LLVM::FSubOp>(
226  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
227  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
228 
229  result.setReal(
230  rewriter, loc,
231  rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
232  result.setImaginary(
233  rewriter, loc,
234  rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
235 
236  rewriter.replaceOp(op, {result});
237  return success();
238  }
239 };
240 
241 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
243 
245  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
246  ConversionPatternRewriter &rewriter) const override {
247  auto loc = op.getLoc();
248  BinaryComplexOperands arg =
249  unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
250 
251  // Initialize complex number struct for result.
252  auto structType = typeConverter->convertType(op.getType());
253  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
254 
255  // Emit IR to add complex numbers.
256  auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
257  Value rhsRe = arg.rhs.real();
258  Value rhsIm = arg.rhs.imag();
259  Value lhsRe = arg.lhs.real();
260  Value lhsIm = arg.lhs.imag();
261 
262  Value real = rewriter.create<LLVM::FSubOp>(
263  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
264  rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
265 
266  Value imag = rewriter.create<LLVM::FAddOp>(
267  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
268  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
269 
270  result.setReal(rewriter, loc, real);
271  result.setImaginary(rewriter, loc, imag);
272 
273  rewriter.replaceOp(op, {result});
274  return success();
275  }
276 };
277 
278 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
280 
282  matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
283  ConversionPatternRewriter &rewriter) const override {
284  auto loc = op.getLoc();
285  BinaryComplexOperands arg =
286  unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
287 
288  // Initialize complex number struct for result.
289  auto structType = typeConverter->convertType(op.getType());
290  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
291 
292  // Emit IR to substract complex numbers.
293  auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {});
294  Value real =
295  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
296  Value imag =
297  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
298  result.setReal(rewriter, loc, real);
299  result.setImaginary(rewriter, loc, imag);
300 
301  rewriter.replaceOp(op, {result});
302  return success();
303  }
304 };
305 } // namespace
306 
308  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
309  // clang-format off
310  patterns.add<
311  AbsOpConversion,
312  AddOpConversion,
313  ConstantOpLowering,
314  CreateOpConversion,
315  DivOpConversion,
316  ImOpConversion,
317  MulOpConversion,
318  ReOpConversion,
319  SubOpConversion
320  >(converter);
321  // clang-format on
322 }
323 
324 namespace {
325 struct ConvertComplexToLLVMPass
326  : public impl::ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
327  void runOnOperation() override;
328 };
329 } // namespace
330 
331 void ConvertComplexToLLVMPass::runOnOperation() {
332  // Convert to the LLVM IR dialect using the converter defined above.
333  RewritePatternSet patterns(&getContext());
334  LLVMTypeConverter converter(&getContext());
335  populateComplexToLLVMConversionPatterns(converter, patterns);
336 
337  LLVMConversionTarget target(getContext());
338  target.addIllegalDialect<complex::ComplexDialect>();
339  if (failed(
340  applyPartialConversion(getOperation(), target, std::move(patterns))))
341  signalPassFailure();
342 }
343 
344 std::unique_ptr<Pass> mlir::createConvertComplexToLLVMPass() {
345  return std::make_unique<ConvertComplexToLLVMPass>();
346 }
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static ComplexStructBuilder undef(OpBuilder &builder, Location loc, Type type)
Build IR creating an undef value of the complex number type.
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:135
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:139
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:199
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:305
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
std::unique_ptr< Pass > createConvertComplexToLLVMPass()
Create a pass to convert Complex operations to the LLVMIR dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to LLVM.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26