MLIR  16.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "../PassDetail.h"
17 
18 using namespace mlir;
19 using namespace mlir::LLVM;
20 
21 //===----------------------------------------------------------------------===//
22 // ComplexStructBuilder implementation.
23 //===----------------------------------------------------------------------===//
24 
// Field indices used by ComplexStructBuilder when inserting into / extracting
// from the struct value that represents a lowered complex number.
static constexpr unsigned kRealPosInComplexNumberStruct{0};
static constexpr unsigned kImaginaryPosInComplexNumberStruct{1};
27 
29  Location loc, Type type) {
30  Value val = builder.create<LLVM::UndefOp>(loc, type);
31  return ComplexStructBuilder(val);
32 }
33 
35  Value real) {
36  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
37 }
38 
40  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
41 }
42 
44  Value imaginary) {
45  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
46 }
47 
49  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // Conversion patterns.
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 
58 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
60 
62  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override {
64  auto loc = op.getLoc();
65 
66  ComplexStructBuilder complexStruct(adaptor.getComplex());
67  Value real = complexStruct.real(rewriter, op.getLoc());
68  Value imag = complexStruct.imaginary(rewriter, op.getLoc());
69 
70  auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
71  Value sqNorm = rewriter.create<LLVM::FAddOp>(
72  loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
73  rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
74 
75  rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
76  return success();
77  }
78 };
79 
80 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
82 
84  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
85  ConversionPatternRewriter &rewriter) const override {
87  op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
88  *getTypeConverter(), rewriter);
89  }
90 };
91 
92 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
94 
96  matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
97  ConversionPatternRewriter &rewriter) const override {
98  // Pack real and imaginary part in a complex number struct.
99  auto loc = complexOp.getLoc();
100  auto structType = typeConverter->convertType(complexOp.getType());
101  auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
102  complexStruct.setReal(rewriter, loc, adaptor.getReal());
103  complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
104 
105  rewriter.replaceOp(complexOp, {complexStruct});
106  return success();
107  }
108 };
109 
110 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
112 
114  matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
115  ConversionPatternRewriter &rewriter) const override {
116  // Extract real part from the complex number struct.
117  ComplexStructBuilder complexStruct(adaptor.getComplex());
118  Value real = complexStruct.real(rewriter, op.getLoc());
119  rewriter.replaceOp(op, real);
120 
121  return success();
122  }
123 };
124 
125 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
127 
129  matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
130  ConversionPatternRewriter &rewriter) const override {
131  // Extract imaginary part from the complex number struct.
132  ComplexStructBuilder complexStruct(adaptor.getComplex());
133  Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
134  rewriter.replaceOp(op, imaginary);
135 
136  return success();
137  }
138 };
139 
140 struct BinaryComplexOperands {
141  std::complex<Value> lhs;
142  std::complex<Value> rhs;
143 };
144 
145 template <typename OpTy>
146 BinaryComplexOperands
147 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
148  ConversionPatternRewriter &rewriter) {
149  auto loc = op.getLoc();
150 
151  // Extract real and imaginary values from operands.
152  BinaryComplexOperands unpacked;
153  ComplexStructBuilder lhs(adaptor.getLhs());
154  unpacked.lhs.real(lhs.real(rewriter, loc));
155  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
156  ComplexStructBuilder rhs(adaptor.getRhs());
157  unpacked.rhs.real(rhs.real(rewriter, loc));
158  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
159 
160  return unpacked;
161 }
162 
163 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
165 
167  matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
168  ConversionPatternRewriter &rewriter) const override {
169  auto loc = op.getLoc();
170  BinaryComplexOperands arg =
171  unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
172 
173  // Initialize complex number struct for result.
174  auto structType = typeConverter->convertType(op.getType());
175  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
176 
177  // Emit IR to add complex numbers.
178  auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
179  Value real =
180  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
181  Value imag =
182  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
183  result.setReal(rewriter, loc, real);
184  result.setImaginary(rewriter, loc, imag);
185 
186  rewriter.replaceOp(op, {result});
187  return success();
188  }
189 };
190 
191 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
193 
195  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
196  ConversionPatternRewriter &rewriter) const override {
197  auto loc = op.getLoc();
198  BinaryComplexOperands arg =
199  unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
200 
201  // Initialize complex number struct for result.
202  auto structType = typeConverter->convertType(op.getType());
203  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
204 
205  // Emit IR to add complex numbers.
206  auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
207  Value rhsRe = arg.rhs.real();
208  Value rhsIm = arg.rhs.imag();
209  Value lhsRe = arg.lhs.real();
210  Value lhsIm = arg.lhs.imag();
211 
212  Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
213  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
214  rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
215 
216  Value resultReal = rewriter.create<LLVM::FAddOp>(
217  loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
218  rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
219 
220  Value resultImag = rewriter.create<LLVM::FSubOp>(
221  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
222  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
223 
224  result.setReal(
225  rewriter, loc,
226  rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
227  result.setImaginary(
228  rewriter, loc,
229  rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
230 
231  rewriter.replaceOp(op, {result});
232  return success();
233  }
234 };
235 
236 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
238 
240  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
241  ConversionPatternRewriter &rewriter) const override {
242  auto loc = op.getLoc();
243  BinaryComplexOperands arg =
244  unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
245 
246  // Initialize complex number struct for result.
247  auto structType = typeConverter->convertType(op.getType());
248  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
249 
250  // Emit IR to add complex numbers.
251  auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
252  Value rhsRe = arg.rhs.real();
253  Value rhsIm = arg.rhs.imag();
254  Value lhsRe = arg.lhs.real();
255  Value lhsIm = arg.lhs.imag();
256 
257  Value real = rewriter.create<LLVM::FSubOp>(
258  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
259  rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
260 
261  Value imag = rewriter.create<LLVM::FAddOp>(
262  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
263  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
264 
265  result.setReal(rewriter, loc, real);
266  result.setImaginary(rewriter, loc, imag);
267 
268  rewriter.replaceOp(op, {result});
269  return success();
270  }
271 };
272 
273 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
275 
277  matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
278  ConversionPatternRewriter &rewriter) const override {
279  auto loc = op.getLoc();
280  BinaryComplexOperands arg =
281  unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
282 
283  // Initialize complex number struct for result.
284  auto structType = typeConverter->convertType(op.getType());
285  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
286 
287  // Emit IR to substract complex numbers.
288  auto fmf = LLVM::FMFAttr::get(op.getContext(), {});
289  Value real =
290  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
291  Value imag =
292  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
293  result.setReal(rewriter, loc, real);
294  result.setImaginary(rewriter, loc, imag);
295 
296  rewriter.replaceOp(op, {result});
297  return success();
298  }
299 };
300 } // namespace
301 
303  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
304  // clang-format off
305  patterns.add<
306  AbsOpConversion,
307  AddOpConversion,
308  ConstantOpLowering,
309  CreateOpConversion,
310  DivOpConversion,
311  ImOpConversion,
312  MulOpConversion,
313  ReOpConversion,
314  SubOpConversion
315  >(converter);
316  // clang-format on
317 }
318 
319 namespace {
320 struct ConvertComplexToLLVMPass
321  : public ConvertComplexToLLVMBase<ConvertComplexToLLVMPass> {
322  void runOnOperation() override;
323 };
324 } // namespace
325 
326 void ConvertComplexToLLVMPass::runOnOperation() {
327  // Convert to the LLVM IR dialect using the converter defined above.
328  RewritePatternSet patterns(&getContext());
329  LLVMTypeConverter converter(&getContext());
330  populateComplexToLLVMConversionPatterns(converter, patterns);
331 
332  LLVMConversionTarget target(getContext());
333  target.addIllegalDialect<complex::ComplexDialect>();
334  if (failed(
335  applyPartialConversion(getOperation(), target, std::move(patterns))))
336  signalPassFailure();
337 }
338 
339 std::unique_ptr<Pass> mlir::createConvertComplexToLLVMPass() {
340  return std::make_unique<ConvertComplexToLLVMPass>();
341 }
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to LLVM.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:132
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void setReal(OpBuilder &builder, Location loc, Value real)
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:136
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static ComplexStructBuilder undef(OpBuilder &builder, Location loc, Type type)
Build IR creating an undef value of the complex number type.
Derived class that automatically populates legalization information for different LLVM ops.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:48
Value imaginary(OpBuilder &builder, Location loc)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Value real(OpBuilder &builder, Location loc)
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
std::unique_ptr< Pass > createConvertComplexToLLVMPass()
Create a pass to convert Complex operations to the LLVMIR dialect.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations of these dialects are not supported by the target.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
static constexpr unsigned kRealPosInComplexNumberStruct
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:309
This class helps build Operations.
Definition: Builders.h:192
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)