MLIR  21.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 using namespace mlir::LLVM;
28 using namespace mlir::arith;
29 
30 //===----------------------------------------------------------------------===//
31 // ComplexStructBuilder implementation.
32 //===----------------------------------------------------------------------===//
33 
/// Position of the real component inside the struct value that represents a
/// complex number.
static constexpr unsigned kRealPosInComplexNumberStruct{0};
/// Position of the imaginary component inside the struct value that
/// represents a complex number.
static constexpr unsigned kImaginaryPosInComplexNumberStruct{1};
36 
38  Location loc, Type type) {
39  Value val = builder.create<LLVM::PoisonOp>(loc, type);
40  return ComplexStructBuilder(val);
41 }
42 
44  Value real) {
45  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
46 }
47 
49  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
50 }
51 
53  Value imaginary) {
54  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
55 }
56 
58  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // Conversion patterns.
63 //===----------------------------------------------------------------------===//
64 
65 namespace {
66 
67 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
69 
70  LogicalResult
71  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73  auto loc = op.getLoc();
74 
75  ComplexStructBuilder complexStruct(adaptor.getComplex());
76  Value real = complexStruct.real(rewriter, op.getLoc());
77  Value imag = complexStruct.imaginary(rewriter, op.getLoc());
78 
79  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
80  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
81  op.getContext(),
82  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
83  Value sqNorm = rewriter.create<LLVM::FAddOp>(
84  loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
85  rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
86 
87  rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
88  return success();
89  }
90 };
91 
92 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
94 
95  LogicalResult
96  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
97  ConversionPatternRewriter &rewriter) const override {
99  op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
100  op->getAttrs(), *getTypeConverter(), rewriter);
101  }
102 };
103 
104 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
106 
107  LogicalResult
108  matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
109  ConversionPatternRewriter &rewriter) const override {
110  // Pack real and imaginary part in a complex number struct.
111  auto loc = complexOp.getLoc();
112  auto structType = typeConverter->convertType(complexOp.getType());
113  auto complexStruct =
114  ComplexStructBuilder::poison(rewriter, loc, structType);
115  complexStruct.setReal(rewriter, loc, adaptor.getReal());
116  complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117 
118  rewriter.replaceOp(complexOp, {complexStruct});
119  return success();
120  }
121 };
122 
123 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
125 
126  LogicalResult
127  matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
128  ConversionPatternRewriter &rewriter) const override {
129  // Extract real part from the complex number struct.
130  ComplexStructBuilder complexStruct(adaptor.getComplex());
131  Value real = complexStruct.real(rewriter, op.getLoc());
132  rewriter.replaceOp(op, real);
133 
134  return success();
135  }
136 };
137 
138 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
140 
141  LogicalResult
142  matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
143  ConversionPatternRewriter &rewriter) const override {
144  // Extract imaginary part from the complex number struct.
145  ComplexStructBuilder complexStruct(adaptor.getComplex());
146  Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
147  rewriter.replaceOp(op, imaginary);
148 
149  return success();
150  }
151 };
152 
153 struct BinaryComplexOperands {
154  std::complex<Value> lhs;
155  std::complex<Value> rhs;
156 };
157 
158 template <typename OpTy>
159 BinaryComplexOperands
160 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
161  ConversionPatternRewriter &rewriter) {
162  auto loc = op.getLoc();
163 
164  // Extract real and imaginary values from operands.
165  BinaryComplexOperands unpacked;
166  ComplexStructBuilder lhs(adaptor.getLhs());
167  unpacked.lhs.real(lhs.real(rewriter, loc));
168  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
169  ComplexStructBuilder rhs(adaptor.getRhs());
170  unpacked.rhs.real(rhs.real(rewriter, loc));
171  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
172 
173  return unpacked;
174 }
175 
176 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
178 
179  LogicalResult
180  matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override {
182  auto loc = op.getLoc();
183  BinaryComplexOperands arg =
184  unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
185 
186  // Initialize complex number struct for result.
187  auto structType = typeConverter->convertType(op.getType());
188  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
189 
190  // Emit IR to add complex numbers.
191  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
192  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
193  op.getContext(),
194  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
195  Value real =
196  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
197  Value imag =
198  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
199  result.setReal(rewriter, loc, real);
200  result.setImaginary(rewriter, loc, imag);
201 
202  rewriter.replaceOp(op, {result});
203  return success();
204  }
205 };
206 
207 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
208  DivOpConversion(const LLVMTypeConverter &converter,
209  complex::ComplexRangeFlags target)
210  : ConvertOpToLLVMPattern<complex::DivOp>(converter),
211  complexRange(target) {}
212 
214 
215  LogicalResult
216  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
217  ConversionPatternRewriter &rewriter) const override {
218  auto loc = op.getLoc();
219  BinaryComplexOperands arg =
220  unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
221 
222  // Initialize complex number struct for result.
223  auto structType = typeConverter->convertType(op.getType());
224  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
225 
226  // Emit IR to add complex numbers.
227  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
228  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
229  op.getContext(),
230  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
231  Value rhsRe = arg.rhs.real();
232  Value rhsIm = arg.rhs.imag();
233  Value lhsRe = arg.lhs.real();
234  Value lhsIm = arg.lhs.imag();
235 
236  Value resultRe, resultIm;
237 
238  if (complexRange == complex::ComplexRangeFlags::basic ||
239  complexRange == complex::ComplexRangeFlags::none) {
241  rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
242  } else if (complexRange == complex::ComplexRangeFlags::improved) {
244  rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
245  }
246 
247  result.setReal(rewriter, loc, resultRe);
248  result.setImaginary(rewriter, loc, resultIm);
249 
250  rewriter.replaceOp(op, {result});
251  return success();
252  }
253 
254 private:
255  complex::ComplexRangeFlags complexRange;
256 };
257 
258 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
260 
261  LogicalResult
262  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263  ConversionPatternRewriter &rewriter) const override {
264  auto loc = op.getLoc();
265  BinaryComplexOperands arg =
266  unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
267 
268  // Initialize complex number struct for result.
269  auto structType = typeConverter->convertType(op.getType());
270  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
271 
272  // Emit IR to add complex numbers.
273  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
274  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
275  op.getContext(),
276  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
277  Value rhsRe = arg.rhs.real();
278  Value rhsIm = arg.rhs.imag();
279  Value lhsRe = arg.lhs.real();
280  Value lhsIm = arg.lhs.imag();
281 
282  Value real = rewriter.create<LLVM::FSubOp>(
283  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
284  rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
285 
286  Value imag = rewriter.create<LLVM::FAddOp>(
287  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
288  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
289 
290  result.setReal(rewriter, loc, real);
291  result.setImaginary(rewriter, loc, imag);
292 
293  rewriter.replaceOp(op, {result});
294  return success();
295  }
296 };
297 
298 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
300 
301  LogicalResult
302  matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
303  ConversionPatternRewriter &rewriter) const override {
304  auto loc = op.getLoc();
305  BinaryComplexOperands arg =
306  unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
307 
308  // Initialize complex number struct for result.
309  auto structType = typeConverter->convertType(op.getType());
310  auto result = ComplexStructBuilder::poison(rewriter, loc, structType);
311 
312  // Emit IR to substract complex numbers.
313  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
314  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
315  op.getContext(),
316  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
317  Value real =
318  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
319  Value imag =
320  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
321  result.setReal(rewriter, loc, real);
322  result.setImaginary(rewriter, loc, imag);
323 
324  rewriter.replaceOp(op, {result});
325  return success();
326  }
327 };
328 } // namespace
329 
331  const LLVMTypeConverter &converter, RewritePatternSet &patterns,
332  complex::ComplexRangeFlags complexRange) {
333  // clang-format off
334  patterns.add<
335  AbsOpConversion,
336  AddOpConversion,
337  ConstantOpLowering,
338  CreateOpConversion,
339  ImOpConversion,
340  MulOpConversion,
341  ReOpConversion,
342  SubOpConversion
343  >(converter);
344 
345  patterns.add<DivOpConversion>(converter, complexRange);
346  // clang-format on
347 }
348 
349 namespace {
350 struct ConvertComplexToLLVMPass
351  : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
352  using Base::Base;
353 
354  void runOnOperation() override;
355 };
356 } // namespace
357 
358 void ConvertComplexToLLVMPass::runOnOperation() {
359  // Convert to the LLVM IR dialect using the converter defined above.
361  LLVMTypeConverter converter(&getContext());
362  populateComplexToLLVMConversionPatterns(converter, patterns, complexRange);
363 
365  target.addIllegalDialect<complex::ComplexDialect>();
366  if (failed(
367  applyPartialConversion(getOperation(), target, std::move(patterns))))
368  signalPassFailure();
369 }
370 
371 //===----------------------------------------------------------------------===//
372 // ConvertToLLVMPatternInterface implementation
373 //===----------------------------------------------------------------------===//
374 
375 namespace {
376 /// Implement the interface to convert MemRef to LLVM.
377 struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
379  void loadDependentDialects(MLIRContext *context) const final {
380  context->loadDialect<LLVM::LLVMDialect>();
381  }
382 
383  /// Hook for derived dialect interface to provide conversion patterns
384  /// and mark dialect legal for the conversion target.
385  void populateConvertToLLVMConversionPatterns(
386  ConversionTarget &target, LLVMTypeConverter &typeConverter,
387  RewritePatternSet &patterns) const final {
389  }
390 };
391 } // namespace
392 
394  registry.addExtension(
395  +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
396  dialect->addInterfaces<ComplexToLLVMDialectInterface>();
397  });
398 }
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:155
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:161
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:345
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the LLVM dialect using Smith's method
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
convert a complex division to the LLVM dialect using algebraic method
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.