MLIR  20.0.0git
ComplexToLLVM.cpp
Go to the documentation of this file.
1 //===- ComplexToLLVM.cpp - conversion from Complex to LLVM dialect --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/Pass/Pass.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 using namespace mlir::LLVM;
27 using namespace mlir::arith;
28 
29 //===----------------------------------------------------------------------===//
30 // ComplexStructBuilder implementation.
31 //===----------------------------------------------------------------------===//
32 
// Field positions inside the struct that represents a complex number. These
// are the indices the ComplexStructBuilder accessors in this section hand to
// setPtr / extractPtr.
33 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
34 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
35 
// NOTE(review): the line carrying the return type, the function name and the
// first parameter (listing line 36) is missing from this extract; only the
// tail of the signature is visible here.
37  Location loc, Type type) {
    // Emit an LLVM::UndefOp of `type` at `loc`; its result is the value the
    // returned builder wraps.
38  Value val = builder.create<LLVM::UndefOp>(loc, type);
39  return ComplexStructBuilder(val);
40 }
41 
// NOTE(review): the first signature line (listing line 42) is missing from
// this extract; only the trailing `real` parameter is visible.
43  Value real) {
    // Store `real` at the real-part position of the wrapped struct.
44  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
45 }
46 
// NOTE(review): the whole signature line (listing line 47) is missing from
// this extract; presumably this is the getter for the real part -- confirm
// against the original file.
    // Return whatever extractPtr yields for the real-part position.
48  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
49 }
50 
// NOTE(review): the first signature line (listing line 51) is missing from
// this extract; only the trailing `imaginary` parameter is visible.
52  Value imaginary) {
    // Store `imaginary` at the imaginary-part position of the wrapped struct.
53  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
54 }
55 
// NOTE(review): the whole signature line (listing line 56) is missing from
// this extract; presumably this is the getter for the imaginary part --
// confirm against the original file.
    // Return whatever extractPtr yields for the imaginary-part position.
57  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // Conversion patterns.
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 
/// Lowers `complex::AbsOp` by computing sqrt(re * re + im * im) with LLVM
/// dialect floating-point ops on the unpacked operand.
66 struct AbsOpConversion : public ConvertOpToLLVMPattern<complex::AbsOp> {
    // NOTE(review): listing line 67 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm against the
    // original file.
68 
69  LogicalResult
70  matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71  ConversionPatternRewriter &rewriter) const override {
72  auto loc = op.getLoc();
73 
    // Wrap the already-converted operand and pull out its two components.
74  ComplexStructBuilder complexStruct(adaptor.getComplex());
75  Value real = complexStruct.real(rewriter, op.getLoc());
76  Value imag = complexStruct.imaginary(rewriter, op.getLoc());
77 
    // Translate the op's arith fast-math flags into the LLVM attribute that
    // is attached to the fmul/fadd ops below.
78  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
80  op.getContext(),
81  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // Squared norm: re * re + im * im. Both fmul ops are created inside the
    // fadd's argument list, so C++ does not fix which of the two is created
    // first.
82  Value sqNorm = rewriter.create<LLVM::FAddOp>(
83  loc, rewriter.create<LLVM::FMulOp>(loc, real, real, fmf),
84  rewriter.create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
85 
    // Replace the op with the square root of the squared norm. Note that
    // `fmf` is not passed to the sqrt op here.
86  rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
87  return success();
88  }
89 };
90 
/// Lowers `complex::ConstantOp` by rewriting it to the operation named by
/// `LLVM::ConstantOp::getOperationName()`, forwarding the adapted operands
/// and the op's attributes.
91 struct ConstantOpLowering : public ConvertOpToLLVMPattern<complex::ConstantOp> {
    // NOTE(review): listing line 92 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
93 
94  LogicalResult
95  matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96  ConversionPatternRewriter &rewriter) const override {
    // NOTE(review): listing line 97, which opens the call whose arguments
    // follow, is missing from this extract. The argument list lines up with
    // the `oneToOneRewrite` helper listed in this page's symbol index --
    // confirm against the original file.
98  op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99  op->getAttrs(), *getTypeConverter(), rewriter);
100  }
101 };
102 
/// Lowers `complex::CreateOp` by building a struct of the converted result
/// type and storing the real and imaginary operands into it.
103 struct CreateOpConversion : public ConvertOpToLLVMPattern<complex::CreateOp> {
    // NOTE(review): listing line 104 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
105 
106  LogicalResult
107  matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108  ConversionPatternRewriter &rewriter) const override {
109  // Pack real and imaginary part in a complex number struct.
110  auto loc = complexOp.getLoc();
    // The converted type is used as-is; no check for a failed conversion is
    // made here.
111  auto structType = typeConverter->convertType(complexOp.getType());
    // Start from an undef struct, then fill the real part followed by the
    // imaginary part.
112  auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType);
113  complexStruct.setReal(rewriter, loc, adaptor.getReal());
114  complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
115 
    // The populated struct replaces the original op.
116  rewriter.replaceOp(complexOp, {complexStruct});
117  return success();
118  }
119 };
120 
/// Lowers `complex::ReOp` by replacing it with the real component extracted
/// from the converted operand.
121 struct ReOpConversion : public ConvertOpToLLVMPattern<complex::ReOp> {
    // NOTE(review): listing line 122 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
123 
124  LogicalResult
125  matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
126  ConversionPatternRewriter &rewriter) const override {
127  // Extract real part from the complex number struct.
128  ComplexStructBuilder complexStruct(adaptor.getComplex());
129  Value real = complexStruct.real(rewriter, op.getLoc());
130  rewriter.replaceOp(op, real);
131 
132  return success();
133  }
134 };
135 
/// Lowers `complex::ImOp` by replacing it with the imaginary component
/// extracted from the converted operand.
136 struct ImOpConversion : public ConvertOpToLLVMPattern<complex::ImOp> {
    // NOTE(review): listing line 137 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
138 
139  LogicalResult
140  matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
141  ConversionPatternRewriter &rewriter) const override {
142  // Extract imaginary part from the complex number struct.
143  ComplexStructBuilder complexStruct(adaptor.getComplex());
144  Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
145  rewriter.replaceOp(op, imaginary);
146 
147  return success();
148  }
149 };
150 
/// Holds the real and imaginary SSA values of both operands of a binary
/// complex op, filled in by unpackBinaryComplexOperands.
// NOTE(review): std::complex is instantiated with `Value` rather than a
// floating-point type; the C++ standard leaves the effect of such
// instantiations unspecified, so this relies on the library's accessors
// behaving as plain getters/setters -- confirm this is acceptable.
151 struct BinaryComplexOperands {
152  std::complex<Value> lhs;
153  std::complex<Value> rhs;
154 };
155 
/// Extracts the real and imaginary parts of the `lhs` and `rhs` operands of
/// a binary complex op.
///
/// `op` supplies the location, `adaptor` supplies the already-converted
/// operands via getLhs()/getRhs(), and `rewriter` is the builder handed to
/// the component accessors. Returns the four extracted values bundled in a
/// BinaryComplexOperands.
156 template <typename OpTy>
157 BinaryComplexOperands
158 unpackBinaryComplexOperands(OpTy op, typename OpTy::Adaptor adaptor,
159  ConversionPatternRewriter &rewriter) {
160  auto loc = op.getLoc();
161 
162  // Extract real and imaginary values from operands.
    // The accessor calls run in the order lhs.real, lhs.imag, rhs.real,
    // rhs.imag; keep that order when editing, since each call goes through
    // the rewriter.
163  BinaryComplexOperands unpacked;
164  ComplexStructBuilder lhs(adaptor.getLhs());
165  unpacked.lhs.real(lhs.real(rewriter, loc));
166  unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
167  ComplexStructBuilder rhs(adaptor.getRhs());
168  unpacked.rhs.real(rhs.real(rewriter, loc));
169  unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
170 
171  return unpacked;
172 }
173 
/// Lowers `complex::AddOp` to two LLVM fadd ops, one per component, and
/// packs the sums into a new complex struct.
174 struct AddOpConversion : public ConvertOpToLLVMPattern<complex::AddOp> {
    // NOTE(review): listing line 175 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
176 
177  LogicalResult
178  matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
179  ConversionPatternRewriter &rewriter) const override {
180  auto loc = op.getLoc();
    // Split both operands into their real and imaginary values.
181  BinaryComplexOperands arg =
182  unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
183 
184  // Initialize complex number struct for result.
185  auto structType = typeConverter->convertType(op.getType());
186  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
187 
188  // Emit IR to add complex numbers.
    // The op's arith fast-math flags are converted once and attached to both
    // fadd ops.
189  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
190  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
191  op.getContext(),
192  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // Component-wise sum: (a + c) and (b + d).
193  Value real =
194  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
195  Value imag =
196  rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
197  result.setReal(rewriter, loc, real);
198  result.setImaginary(rewriter, loc, imag);
199 
    // The populated struct replaces the original op.
200  rewriter.replaceOp(op, {result});
201  return success();
202  }
203 };
204 
/// Lowers `complex::DivOp` using the textbook formula
///   (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c*c + d*d)
/// built from LLVM fmul/fadd/fsub/fdiv ops. The code below applies the
/// formula directly; it contains no scaling or special-casing of the divisor.
205 struct DivOpConversion : public ConvertOpToLLVMPattern<complex::DivOp> {
    // NOTE(review): listing line 206 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
207 
208  LogicalResult
209  matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
210  ConversionPatternRewriter &rewriter) const override {
211  auto loc = op.getLoc();
    // Split both operands into their real and imaginary values.
212  BinaryComplexOperands arg =
213  unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
214 
215  // Initialize complex number struct for result.
216  auto structType = typeConverter->convertType(op.getType());
217  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
218 
219  // Emit IR to divide complex numbers.
    // The op's arith fast-math flags are converted once and attached to every
    // floating-point op created below.
220  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
221  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
222  op.getContext(),
223  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
224  Value rhsRe = arg.rhs.real();
225  Value rhsIm = arg.rhs.imag();
226  Value lhsRe = arg.lhs.real();
227  Value lhsIm = arg.lhs.imag();
228 
    // Denominator: c*c + d*d. In this and the next two statements the two
    // fmul ops are created inside the enclosing create call's argument list,
    // so C++ does not fix which of the pair is created first.
229  Value rhsSqNorm = rewriter.create<LLVM::FAddOp>(
230  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
231  rewriter.create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
232 
    // Real numerator: a*c + b*d.
233  Value resultReal = rewriter.create<LLVM::FAddOp>(
234  loc, rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
235  rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
236 
    // Imaginary numerator: b*c - a*d.
237  Value resultImag = rewriter.create<LLVM::FSubOp>(
238  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
239  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
240 
    // Divide each numerator by the shared denominator and store the results.
241  result.setReal(
242  rewriter, loc,
243  rewriter.create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
244  result.setImaginary(
245  rewriter, loc,
246  rewriter.create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
247 
    // The populated struct replaces the original op.
248  rewriter.replaceOp(op, {result});
249  return success();
250  }
251 };
252 
/// Lowers `complex::MulOp` using
///   (a + bi) * (c + di) = (ac - bd) + (bc + ad)i
/// built from LLVM fmul/fsub/fadd ops.
253 struct MulOpConversion : public ConvertOpToLLVMPattern<complex::MulOp> {
    // NOTE(review): listing line 254 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
255 
256  LogicalResult
257  matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
258  ConversionPatternRewriter &rewriter) const override {
259  auto loc = op.getLoc();
    // Split both operands into their real and imaginary values.
260  BinaryComplexOperands arg =
261  unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
262 
263  // Initialize complex number struct for result.
264  auto structType = typeConverter->convertType(op.getType());
265  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
266 
267  // Emit IR to multiply complex numbers.
    // The op's arith fast-math flags are converted once and attached to every
    // floating-point op created below.
268  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
269  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
270  op.getContext(),
271  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
272  Value rhsRe = arg.rhs.real();
273  Value rhsIm = arg.rhs.imag();
274  Value lhsRe = arg.lhs.real();
275  Value lhsIm = arg.lhs.imag();
276 
    // Real part: c*a - d*b. The two fmul ops are created inside the fsub's
    // argument list, so C++ does not fix which of them is created first (the
    // same holds for the fadd below).
277  Value real = rewriter.create<LLVM::FSubOp>(
278  loc, rewriter.create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
279  rewriter.create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
280 
    // Imaginary part: b*c + a*d.
281  Value imag = rewriter.create<LLVM::FAddOp>(
282  loc, rewriter.create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
283  rewriter.create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
284 
285  result.setReal(rewriter, loc, real);
286  result.setImaginary(rewriter, loc, imag);
287 
    // The populated struct replaces the original op.
288  rewriter.replaceOp(op, {result});
289  return success();
290  }
291 };
292 
/// Lowers `complex::SubOp` to two LLVM fsub ops, one per component, and
/// packs the differences into a new complex struct.
293 struct SubOpConversion : public ConvertOpToLLVMPattern<complex::SubOp> {
    // NOTE(review): listing line 294 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
295 
296  LogicalResult
297  matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
298  ConversionPatternRewriter &rewriter) const override {
299  auto loc = op.getLoc();
    // Split both operands into their real and imaginary values.
300  BinaryComplexOperands arg =
301  unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
302 
303  // Initialize complex number struct for result.
304  auto structType = typeConverter->convertType(op.getType());
305  auto result = ComplexStructBuilder::undef(rewriter, loc, structType);
306 
307  // Emit IR to subtract complex numbers.
    // The op's arith fast-math flags are converted once and attached to both
    // fsub ops.
308  arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
309  LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
310  op.getContext(),
311  convertArithFastMathFlagsToLLVM(complexFMFAttr.getValue()));
    // Component-wise difference: (a - c) and (b - d).
312  Value real =
313  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
314  Value imag =
315  rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
316  result.setReal(rewriter, loc, real);
317  result.setImaginary(rewriter, loc, imag);
318 
    // The populated struct replaces the original op.
319  rewriter.replaceOp(op, {result});
320  return success();
321  }
322 };
323 } // namespace
324 
// NOTE(review): the line carrying the return type and function name (listing
// line 325) is missing from this extract; only the parameter list is visible.
// Adds every conversion pattern defined in this file to `patterns`, each
// constructed from `converter`.
326  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
327  // clang-format off
328  patterns.add<
329  AbsOpConversion,
330  AddOpConversion,
331  ConstantOpLowering,
332  CreateOpConversion,
333  DivOpConversion,
334  ImOpConversion,
335  MulOpConversion,
336  ReOpConversion,
337  SubOpConversion
338  >(converter);
339  // clang-format on
340 }
341 
342 namespace {
/// Pass type built on the generated `impl::ConvertComplexToLLVMPassBase`.
/// It inherits the base constructors and only declares runOnOperation, which
/// is defined out of line.
343 struct ConvertComplexToLLVMPass
344  : public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
345  using Base::Base;
346 
347  void runOnOperation() override;
348 };
349 } // namespace
350 
/// Runs the Complex-to-LLVM conversion on the pass's operation: collects the
/// patterns from this file, marks the complex dialect illegal, applies a
/// partial conversion, and signals pass failure if that conversion fails.
351 void ConvertComplexToLLVMPass::runOnOperation() {
352  // Convert to the LLVM IR dialect using a locally built type converter.
353  RewritePatternSet patterns(&getContext());
354  LLVMTypeConverter converter(&getContext());
355  populateComplexToLLVMConversionPatterns(converter, patterns);
356 
    // NOTE(review): listing line 357 is missing from this extract; presumably
    // it declares the conversion target `target` used below -- confirm
    // against the original file.
    // Every op of the complex dialect must be rewritten for the conversion to
    // succeed.
358  target.addIllegalDialect<complex::ComplexDialect>();
359  if (failed(
360  applyPartialConversion(getOperation(), target, std::move(patterns))))
361  signalPassFailure();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // ConvertToLLVMPatternInterface implementation
366 //===----------------------------------------------------------------------===//
367 
368 namespace {
369 /// Implement the interface to convert Complex to LLVM.
370 struct ComplexToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
    // NOTE(review): listing line 371 is absent from this extract; presumably
    // the inherited-constructor using-declaration -- confirm.
    /// Loads the LLVM dialect into `context`.
372  void loadDependentDialects(MLIRContext *context) const final {
373  context->loadDialect<LLVM::LLVMDialect>();
374  }
375 
376  /// Hook for derived dialect interface to provide conversion patterns
377  /// and mark dialect legal for the conversion target.
    /// This implementation only adds the Complex-to-LLVM patterns; it does
    /// not touch `target`.
378  void populateConvertToLLVMConversionPatterns(
379  ConversionTarget &target, LLVMTypeConverter &typeConverter,
380  RewritePatternSet &patterns) const final {
381  populateComplexToLLVMConversionPatterns(typeConverter, patterns);
382  }
383 };
384 } // namespace
385 
// NOTE(review): the signature line (listing line 386) is missing from this
// extract; the body uses a `registry` that is not declared in the visible
// lines.
// Registers an extension that attaches ComplexToLLVMDialectInterface to the
// complex dialect.
387  registry.addExtension(
    // The unary `+` converts the capture-less lambda to a plain function
    // pointer. `ctx` is not used inside the callback.
388  +[](MLIRContext *ctx, complex::ComplexDialect *dialect) {
389  dialect->addInterfaces<ComplexToLLVMDialectInterface>();
390  });
391 }
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
static ComplexStructBuilder undef(OpBuilder &builder, Location loc, Type type)
Build IR creating an undef value of the complex number type.
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition: Pattern.h:147
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:212
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:340
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to LLVM.
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.