21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
38 Value val = LLVM::PoisonOp::create(builder, loc, type);
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter)
const override {
72 auto loc = op.getLoc();
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
82 Value sqNorm = LLVM::FAddOp::create(
83 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84 LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter)
const override {
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), *getTypeConverter(), rewriter);
104 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
108 ConversionPatternRewriter &rewriter)
const override {
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(complexOp.getType());
114 complexStruct.setReal(rewriter, loc, adaptor.getReal());
115 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117 rewriter.replaceOp(complexOp, {complexStruct});
123 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
126 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter)
const override {
129 ComplexStructBuilder complexStruct(adaptor.getComplex());
130 Value real = complexStruct.real(rewriter, op.getLoc());
131 rewriter.replaceOp(op, real);
138 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
141 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter)
const override {
144 ComplexStructBuilder complexStruct(adaptor.getComplex());
145 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
146 rewriter.replaceOp(op, imaginary);
152struct BinaryComplexOperands {
153 std::complex<Value>
lhs;
154 std::complex<Value>
rhs;
157template <
typename OpTy>
159unpackBinaryComplexOperands(OpTy op,
typename OpTy::Adaptor adaptor,
160 ConversionPatternRewriter &rewriter) {
161 auto loc = op.getLoc();
164 BinaryComplexOperands unpacked;
166 unpacked.lhs.real(
lhs.real(rewriter, loc));
167 unpacked.lhs.imag(
lhs.imaginary(rewriter, loc));
169 unpacked.rhs.real(
rhs.real(rewriter, loc));
170 unpacked.rhs.imag(
rhs.imaginary(rewriter, loc));
176 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
179 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
180 ConversionPatternRewriter &rewriter)
const override {
181 auto loc = op.getLoc();
182 BinaryComplexOperands arg =
183 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
186 auto structType = typeConverter->convertType(op.getType());
190 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
191 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
194 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
195 arg.rhs.real(), fmf);
196 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
197 arg.rhs.imag(), fmf);
198 result.setReal(rewriter, loc, real);
199 result.setImaginary(rewriter, loc, imag);
201 rewriter.replaceOp(op, {
result});
207 DivOpConversion(
const LLVMTypeConverter &converter,
208 complex::ComplexRangeFlags
target)
209 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
212 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
215 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override {
217 auto loc = op.getLoc();
218 BinaryComplexOperands arg =
219 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
222 auto structType = typeConverter->convertType(op.getType());
226 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
227 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
230 Value rhsRe = arg.rhs.real();
231 Value rhsIm = arg.rhs.imag();
232 Value lhsRe = arg.lhs.real();
233 Value lhsIm = arg.lhs.imag();
235 Value resultRe, resultIm;
237 if (complexRange == complex::ComplexRangeFlags::basic ||
238 complexRange == complex::ComplexRangeFlags::none) {
240 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
241 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
243 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
246 result.setReal(rewriter, loc, resultRe);
247 result.setImaginary(rewriter, loc, resultIm);
249 rewriter.replaceOp(op, {
result});
254 complex::ComplexRangeFlags complexRange;
258 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
261 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const override {
263 auto loc = op.getLoc();
264 BinaryComplexOperands arg =
265 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
268 auto structType = typeConverter->convertType(op.getType());
272 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
273 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
276 Value rhsRe = arg.rhs.real();
277 Value rhsIm = arg.rhs.imag();
278 Value lhsRe = arg.lhs.real();
279 Value lhsIm = arg.lhs.imag();
281 Value real = LLVM::FSubOp::create(
282 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
283 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
285 Value imag = LLVM::FAddOp::create(
286 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
287 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
289 result.setReal(rewriter, loc, real);
290 result.setImaginary(rewriter, loc, imag);
292 rewriter.replaceOp(op, {
result});
298 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
301 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
302 ConversionPatternRewriter &rewriter)
const override {
303 auto loc = op.getLoc();
304 BinaryComplexOperands arg =
305 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
308 auto structType = typeConverter->convertType(op.getType());
312 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
313 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
316 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
317 arg.rhs.real(), fmf);
318 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
319 arg.rhs.imag(), fmf);
320 result.setReal(rewriter, loc, real);
321 result.setImaginary(rewriter, loc, imag);
323 rewriter.replaceOp(op, {
result});
331 complex::ComplexRangeFlags complexRange) {
344 patterns.add<DivOpConversion>(converter, complexRange);
349struct ConvertComplexToLLVMPass
353 void runOnOperation()
override;
357void ConvertComplexToLLVMPass::runOnOperation() {
364 target.addIllegalDialect<complex::ComplexDialect>();
366 applyPartialConversion(getOperation(),
target, std::move(
patterns))))
378 void loadDependentDialects(MLIRContext *context)
const final {
379 context->loadDialect<LLVM::LLVMDialect>();
384 void populateConvertToLLVMConversionPatterns(
385 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
386 RewritePatternSet &
patterns)
const final {
394 +[](
MLIRContext *ctx, complex::ComplexDialect *dialect) {
395 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's method (range reduction).
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)