21#define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22#include "mlir/Conversion/Passes.h.inc"
38 Value val = LLVM::PoisonOp::create(builder, loc, type);
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
71 ConversionPatternRewriter &rewriter)
const override {
72 auto loc = op.getLoc();
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
79 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
82 Value sqNorm = LLVM::FAddOp::create(
83 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84 LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
86 rewriter.replaceOpWithNewOp<LLVM::SqrtOp>(op, sqNorm);
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
96 ConversionPatternRewriter &rewriter)
const override {
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), Attribute{}, *getTypeConverter(),
105 using ConvertOpToLLVMPattern<complex::CreateOp>::ConvertOpToLLVMPattern;
108 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
109 ConversionPatternRewriter &rewriter)
const override {
111 auto loc = complexOp.getLoc();
112 auto structType = typeConverter->convertType(complexOp.getType());
115 complexStruct.setReal(rewriter, loc, adaptor.getReal());
116 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
118 rewriter.replaceOp(complexOp, {complexStruct});
124 using ConvertOpToLLVMPattern<complex::ReOp>::ConvertOpToLLVMPattern;
127 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter)
const override {
130 ComplexStructBuilder complexStruct(adaptor.getComplex());
131 Value real = complexStruct.real(rewriter, op.getLoc());
132 rewriter.replaceOp(op, real);
139 using ConvertOpToLLVMPattern<complex::ImOp>::ConvertOpToLLVMPattern;
142 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter)
const override {
145 ComplexStructBuilder complexStruct(adaptor.getComplex());
146 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
147 rewriter.replaceOp(op, imaginary);
153struct BinaryComplexOperands {
154 std::complex<Value>
lhs;
155 std::complex<Value>
rhs;
158template <
typename OpTy>
160unpackBinaryComplexOperands(OpTy op,
typename OpTy::Adaptor adaptor,
161 ConversionPatternRewriter &rewriter) {
162 auto loc = op.getLoc();
165 BinaryComplexOperands unpacked;
167 unpacked.lhs.real(
lhs.real(rewriter, loc));
168 unpacked.lhs.imag(
lhs.imaginary(rewriter, loc));
170 unpacked.rhs.real(
rhs.real(rewriter, loc));
171 unpacked.rhs.imag(
rhs.imaginary(rewriter, loc));
177 using ConvertOpToLLVMPattern<complex::AddOp>::ConvertOpToLLVMPattern;
180 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter)
const override {
182 auto loc = op.getLoc();
183 BinaryComplexOperands arg =
184 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
187 auto structType = typeConverter->convertType(op.getType());
191 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
192 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
195 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
196 arg.rhs.real(), fmf);
197 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
198 arg.rhs.imag(), fmf);
199 result.setReal(rewriter, loc, real);
200 result.setImaginary(rewriter, loc, imag);
202 rewriter.replaceOp(op, {
result});
208 DivOpConversion(
const LLVMTypeConverter &converter,
209 complex::ComplexRangeFlags
target)
210 : ConvertOpToLLVMPattern<complex::DivOp>(converter),
213 using ConvertOpToLLVMPattern<complex::DivOp>::ConvertOpToLLVMPattern;
216 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
217 ConversionPatternRewriter &rewriter)
const override {
218 auto loc = op.getLoc();
219 BinaryComplexOperands arg =
220 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
223 auto structType = typeConverter->convertType(op.getType());
227 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
228 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
231 Value rhsRe = arg.rhs.real();
232 Value rhsIm = arg.rhs.imag();
233 Value lhsRe = arg.lhs.real();
234 Value lhsIm = arg.lhs.imag();
236 Value resultRe, resultIm;
238 if (complexRange == complex::ComplexRangeFlags::basic ||
239 complexRange == complex::ComplexRangeFlags::none) {
241 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
242 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
244 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
247 result.setReal(rewriter, loc, resultRe);
248 result.setImaginary(rewriter, loc, resultIm);
250 rewriter.replaceOp(op, {
result});
255 complex::ComplexRangeFlags complexRange;
259 using ConvertOpToLLVMPattern<complex::MulOp>::ConvertOpToLLVMPattern;
262 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263 ConversionPatternRewriter &rewriter)
const override {
264 auto loc = op.getLoc();
265 BinaryComplexOperands arg =
266 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
269 auto structType = typeConverter->convertType(op.getType());
273 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
274 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
277 Value rhsRe = arg.rhs.real();
278 Value rhsIm = arg.rhs.imag();
279 Value lhsRe = arg.lhs.real();
280 Value lhsIm = arg.lhs.imag();
282 Value real = LLVM::FSubOp::create(
283 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
284 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
286 Value imag = LLVM::FAddOp::create(
287 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
288 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
290 result.setReal(rewriter, loc, real);
291 result.setImaginary(rewriter, loc, imag);
293 rewriter.replaceOp(op, {
result});
299 using ConvertOpToLLVMPattern<complex::SubOp>::ConvertOpToLLVMPattern;
302 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
304 auto loc = op.getLoc();
305 BinaryComplexOperands arg =
306 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
309 auto structType = typeConverter->convertType(op.getType());
313 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
314 LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get(
317 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
318 arg.rhs.real(), fmf);
319 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
320 arg.rhs.imag(), fmf);
321 result.setReal(rewriter, loc, real);
322 result.setImaginary(rewriter, loc, imag);
324 rewriter.replaceOp(op, {
result});
332 complex::ComplexRangeFlags complexRange) {
345 patterns.add<DivOpConversion>(converter, complexRange);
350struct ConvertComplexToLLVMPass
354 void runOnOperation()
override;
358void ConvertComplexToLLVMPass::runOnOperation() {
365 target.addIllegalDialect<complex::ComplexDialect>();
367 applyPartialConversion(getOperation(),
target, std::move(
patterns))))
379 void loadDependentDialects(MLIRContext *context)
const final {
380 context->loadDialect<LLVM::LLVMDialect>();
385 void populateConvertToLLVMConversionPatterns(
386 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
387 RewritePatternSet &
patterns)
const final {
395 +[](
MLIRContext *ctx, complex::ComplexDialect *dialect) {
396 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr)
Builds IR to set a value in the struct at position pos.
Value extractPtr(OpBuilder &builder, Location loc, unsigned pos) const
Builds IR to extract a value from the struct at position pos.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's method.
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)