MLIR  22.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 using namespace mlir::arith;
28 
29 //===----------------------------------------------------------------------===//
30 // ComplexStructBuilder implementation.
31 //===----------------------------------------------------------------------===//
32 
/// Field indices of the real and imaginary parts inside the LLVM struct value
/// that ComplexStructBuilder reads and writes.
33 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
34 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
35 
37  Location loc, Type type) {
  // Create a poison value of the given struct type; the callers in this file
  // then fill both fields through setReal/setImaginary.
38  Value val = LLVM::PoisonOp::create(builder, loc, type);
39  return ComplexStructBuilder(val);
40 }
41 
43  Value real) {
  // Write `real` into the struct field at kRealPosInComplexNumberStruct.
44  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
45 }
46 
48  return extractPtr(builder, loc, kRealPosInComplexNumberStruct); // Real part.
49 }
50 
52  Value imaginary) {
  // Write `imaginary` into the struct field at
  // kImaginaryPosInComplexNumberStruct.
53  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
54 }
55 
57  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); // Imaginary part.
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Conversion patterns.
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 
/// Lowers `complex.abs` to sqrt(re * re + im * im).
/// The source op's fastmath flags are forwarded to the fmul/fadd ops; the
/// final LLVM::SqrtOp is created without them.
/// NOTE(review): this direct formula can overflow/underflow for components of
/// very large or very small magnitude, where a hypot-style scaled computation
/// would not.
66 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
68 
69  LogicalResult
70  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71  ConversionPatternRewriter &rewriter) const override {
72  auto loc = op.getLoc();
73 
    // Unpack the lowered struct operand into its two scalar fields.
74  ComplexStructBuilder complexStruct(adaptor.getComplex());
75  Value real = complexStruct.real(rewriter, op.getLoc());
76  Value imag = complexStruct.imaginary(rewriter, op.getLoc());
77 
    // Translate the arith fastmath flags of complex.abs into LLVM flags.
78  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80  op.getContext(),
81  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // Squared norm: re * re + im * im.
82  Value sqNorm = LLVM::FAddOp::create(
83  rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84  LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
85 
    // |z| = sqrt(squared norm).
86  rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87  return success();
88  }
89 };
90 
/// Lowers `complex.constant` one-to-one to an LLVM::ConstantOp.
/// The adapted operands and all attributes of the original op are forwarded,
/// and the pattern's type converter is passed to the rewrite helper.
91 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
93 
94  LogicalResult
95  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96  ConversionPatternRewriter &rewriter) const override {
98  op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99  op->getAttrs(), *getTypeConverter(), rewriter);
100  }
101 };
102 
/// Lowers `complex.create` by inserting the (already converted) real and
/// imaginary operands into a poison value of the converted struct type.
/// NOTE(review): the result of convertType is used without a null check.
103 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
105 
106  LogicalResult
107  matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108  ConversionPatternRewriter &rewriter) const override {
109  // Pack real and imaginary part in a complex number struct.
110  auto loc = complexOp.getLoc();
111  auto structType = typeConverter->convertType(complexOp.getType());
112  auto complexStruct =
113  ComplexStructBuilder::poison(rewriter, loc, structType);
114  complexStruct.setReal(rewriter, loc, adaptor.getReal());
115  complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
116 
117  rewriter.replaceOp(complexOp, {complexStruct});
118  return success();
119  }
120 };
121 
/// Lowers `complex.re` to a read of the real field of the lowered struct
/// operand; the op is replaced by that scalar value.
122 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
124 
125  LogicalResult
126  matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
127  ConversionPatternRewriter &rewriter) const override {
128  // Extract real part from the complex number struct.
129  ComplexStructBuilder complexStruct(adaptor.getComplex());
130  Value real = complexStruct.real(rewriter, op.getLoc());
131  rewriter.replaceOp(op, real);
132 
133  return success();
134  }
135 };
136 
/// Lowers `complex.im` to a read of the imaginary field of the lowered struct
/// operand; the op is replaced by that scalar value.
137 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
139 
140  LogicalResult
141  matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter) const override {
143  // Extract imaginary part from the complex number struct.
144  ComplexStructBuilder complexStruct(adaptor.getComplex());
145  Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
146  rewriter.replaceOp(op, imaginary);
147 
148  return success();
149  }
150 };
151 
152 struct BinaryComplexOperands {
153  std::complex<Value> lhs;
154  std::complex<Value> rhs;
155 };
156 
157 template <typename OpTy>
158 BinaryComplexOperands
159 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
160  ConversionPatternRewriter &rewriter) {
161  auto loc = op.getLoc();
162 
163  // Extract real and imaginary values from operands.
164  BinaryComplexOperands unpacked;
165  ComplexStructBuilder lhs(adaptor.getLhs());
166  unpacked.lhs.real(lhs.real(rewriter, loc));
167  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
168  ComplexStructBuilder rhs(adaptor.getRhs());
169  unpacked.rhs.real(rhs.real(rewriter, loc));
170  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
171 
172  return unpacked;
173 }
174 
/// Lowers `complex.add` to component-wise LLVM::FAddOp on the real and
/// imaginary parts, forwarding the op's fastmath flags.
175 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
177 
178  LogicalResult
179  matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
180  ConversionPatternRewriter &rewriter) const override {
181  auto loc = op.getLoc();
182  BinaryComplexOperands arg =
183  unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
184 
185  // Initialize complex number struct for result.
186  auto structType = typeConverter->convertType(op.getType());
187  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
188 
189  // Emit IR to add complex numbers.
190  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
191  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
192  op.getContext(),
193  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // (a + bi) + (c + di) = (a + c) + (b + d)i
194  Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
195  arg.rhs.real(), fmf);
196  Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
197  arg.rhs.imag(), fmf);
198  result.setReal(rewriter, loc, real);
199  result.setImaginary(rewriter, loc, imag);
200 
201  rewriter.replaceOp(op, {result});
202  return success();
203  }
204 };
205 
/// Lowers `complex.div` to LLVM dialect arithmetic.
/// The division strategy is selected by the `complexRange` flag captured at
/// construction: one helper handles `basic` and `none`, another handles
/// `improved`. Both return the result parts through the `resultRe`/`resultIm`
/// out-pointers.
206 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
207  DivOpConversion(const LLVMTypeConverter &converter,
208  complex::ComplexRangeFlags target)
209  : ConvertOpToLLVMPattern<complex::DivOp>(converter),
210  complexRange(target) {}
211 
213 
214  LogicalResult
215  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
216  ConversionPatternRewriter &rewriter) const override {
217  auto loc = op.getLoc();
218  BinaryComplexOperands arg =
219  unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
220 
221  // Initialize complex number struct for result.
222  auto structType = typeConverter->convertType(op.getType());
223  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
224 
225  // Emit IR to divide complex numbers.
226  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
227  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
228  op.getContext(),
229  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
230  Value rhsRe = arg.rhs.real();
231  Value rhsIm = arg.rhs.imag();
232  Value lhsRe = arg.lhs.real();
233  Value lhsIm = arg.lhs.imag();
234 
235  Value resultRe, resultIm;
    // NOTE(review): there is no `else` branch below. A complexRange value other
    // than basic/none/improved would leave resultRe/resultIm default-constructed
    // (null) when they are stored into the result struct. Confirm the enum has
    // no other members.
236 
237  if (complexRange == complex::ComplexRangeFlags::basic ||
238  complexRange == complex::ComplexRangeFlags::none) {
240  rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
241  } else if (complexRange == complex::ComplexRangeFlags::improved) {
243  rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
244  }
245 
246  result.setReal(rewriter, loc, resultRe);
247  result.setImaginary(rewriter, loc, resultIm);
248 
249  rewriter.replaceOp(op, {result});
250  return success();
251  }
252 
253 private:
  /// Division algorithm selector, fixed when the pattern is constructed.
254  complex::ComplexRangeFlags complexRange;
255 };
256 
/// Lowers `complex.mul` using the textbook formula
/// (a + bi)(c + di) = (ac - bd) + (ad + bc)i, built from LLVM fmul/fsub/fadd
/// ops that carry the op's fastmath flags. No special handling of NaN or
/// infinite inputs is performed here.
257 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
259 
260  LogicalResult
261  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
262  ConversionPatternRewriter &rewriter) const override {
263  auto loc = op.getLoc();
264  BinaryComplexOperands arg =
265  unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
266 
267  // Initialize complex number struct for result.
268  auto structType = typeConverter->convertType(op.getType());
269  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
270 
271  // Emit IR to multiply complex numbers.
272  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
273  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
274  op.getContext(),
275  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
276  Value rhsRe = arg.rhs.real();
277  Value rhsIm = arg.rhs.imag();
278  Value lhsRe = arg.lhs.real();
279  Value lhsIm = arg.lhs.imag();
280 
    // Real part: rhsRe * lhsRe - rhsIm * lhsIm.
281  Value real = LLVM::FSubOp::create(
282  rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
283  LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
284 
    // Imaginary part: lhsIm * rhsRe + lhsRe * rhsIm.
285  Value imag = LLVM::FAddOp::create(
286  rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
287  LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
288 
289  result.setReal(rewriter, loc, real);
290  result.setImaginary(rewriter, loc, imag);
291 
292  rewriter.replaceOp(op, {result});
293  return success();
294  }
295 };
296 
/// Lowers `complex.sub` to component-wise LLVM::FSubOp on the real and
/// imaginary parts, forwarding the op's fastmath flags.
297 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
299 
300  LogicalResult
301  matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
302  ConversionPatternRewriter &rewriter) const override {
303  auto loc = op.getLoc();
304  BinaryComplexOperands arg =
305  unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
306 
307  // Initialize complex number struct for result.
308  auto structType = typeConverter->convertType(op.getType());
309  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
310 
311  // Emit IR to subtract complex numbers.
312  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
313  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
314  op.getContext(),
315  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // (a + bi) - (c + di) = (a - c) + (b - d)i
316  Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
317  arg.rhs.real(), fmf);
318  Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
319  arg.rhs.imag(), fmf);
320  result.setReal(rewriter, loc, real);
321  result.setImaginary(rewriter, loc, imag);
322 
323  rewriter.replaceOp(op, {result});
324  return success();
325  }
326 };
327 } // namespace
328 
330  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
331  complex::ComplexRangeFlags complexRange) {
  // Patterns whose constructors only need the type converter.
332  // clang-format off
333  patterns.add<
334  AbsOpConversion,
335  AddOpConversion,
336  ConstantOpLowering,
337  CreateOpConversion,
338  ImOpConversion,
339  MulOpConversion,
340  ReOpConversion,
341  SubOpConversion
342  >(converter);
343 
  // DivOpConversion is the only pattern parameterized by `complexRange`.
344  patterns.add<DivOpConversion>(converter, complexRange);
345  // clang-format on
346 }
347 
348 namespace {
349 struct ConvertComplexToLLVMPass
350  : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
351  using Base::Base;
352 
353  void runOnOperation() override;
354 };
355 } // namespace
356 
/// Lowers the complex dialect ops in the current operation to the LLVM
/// dialect. The complex dialect is marked illegal, so the partial conversion
/// reports failure (and the pass signals failure) if any complex op cannot be
/// lowered.
357 void ConvertComplexToLLVMPass::runOnOperation() {
358  // Convert to the LLVM IR dialect using a default LLVMTypeConverter.
360  LLVMTypeConverter converter(&getContext());
361  populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
362 
364  target.addIllegalDialect<complex::ComplexDialect>();
365  if (failed(
366  applyPartialConversion(getOperation(), target, std::move(patterns))))
367  signalPassFailure();
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // ConvertToLLVMPatternInterface implementation
372 //===----------------------------------------------------------------------===//
373 
374 namespace {
375 /// Implement the interface to convert Complex to LLVM.
376 struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
  // The lowered IR is made of LLVM dialect ops, so that dialect is loaded.
378  void loadDependentDialects(MLIRContext *context) const final {
379  context->loadDialect<LLVM::LLVMDialect>();
380  }
381 
382  /// Hook for derived dialect interface to provide conversion patterns
383  /// and mark dialect legal for the conversion target.
384  void populateConvertToLLVMConversionPatterns(
385  ConversionTarget &target, LLVMTypeConverter &typeConverter,
386  RewritePatternSet &patterns) const final {
388  }
389 };
390 } // namespace
391 
393  registry.addExtension(
  // Register an extension that attaches the ConvertToLLVM pattern interface
  // to the complex dialect.
394  +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
395  dialect->addInterfaces<ComplexToLLVMDialectInterface>();
396  });
397 }
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:209
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:215
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:307
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's method.
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.