21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22 #include "mlir/Conversion/Passes.h.inc"
38 Value val = builder.
create<LLVM::UndefOp>(loc, type);
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
72 auto loc = op.getLoc();
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
83 loc, rewriter.
create<LLVM::FMulOp>(loc, real, real, fmf),
84 rewriter.
create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), *getTypeConverter(), rewriter);
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(complexOp.getType());
113 complexStruct.setReal(rewriter, loc, adaptor.getReal());
114 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
116 rewriter.
replaceOp(complexOp, {complexStruct});
125 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
129 Value real = complexStruct.real(rewriter, op.getLoc());
140 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
144 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
151 struct BinaryComplexOperands {
152 std::complex<Value> lhs;
153 std::complex<Value> rhs;
156 template <
typename OpTy>
157 BinaryComplexOperands
158 unpackBinaryComplexOperands(OpTy op,
typename OpTy::Adaptor adaptor,
160 auto loc = op.getLoc();
163 BinaryComplexOperands unpacked;
165 unpacked.lhs.real(lhs.real(rewriter, loc));
166 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
168 unpacked.rhs.real(rhs.real(rewriter, loc));
169 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
178 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
180 auto loc = op.getLoc();
181 BinaryComplexOperands arg =
182 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
185 auto structType = typeConverter->convertType(op.getType());
189 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
194 rewriter.
create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
196 rewriter.
create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
197 result.setReal(rewriter, loc, real);
198 result.setImaginary(rewriter, loc, imag);
209 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
211 auto loc = op.getLoc();
212 BinaryComplexOperands arg =
213 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
216 auto structType = typeConverter->convertType(op.getType());
220 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
224 Value rhsRe = arg.rhs.real();
225 Value rhsIm = arg.rhs.imag();
226 Value lhsRe = arg.lhs.real();
227 Value lhsIm = arg.lhs.imag();
230 loc, rewriter.
create<LLVM::FMulOp>(loc, rhsRe, rhsRe, fmf),
231 rewriter.
create<LLVM::FMulOp>(loc, rhsIm, rhsIm, fmf), fmf);
234 loc, rewriter.
create<LLVM::FMulOp>(loc, lhsRe, rhsRe, fmf),
235 rewriter.
create<LLVM::FMulOp>(loc, lhsIm, rhsIm, fmf), fmf);
238 loc, rewriter.
create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
239 rewriter.
create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
243 rewriter.
create<LLVM::FDivOp>(loc, resultReal, rhsSqNorm, fmf));
246 rewriter.
create<LLVM::FDivOp>(loc, resultImag, rhsSqNorm, fmf));
257 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
259 auto loc = op.getLoc();
260 BinaryComplexOperands arg =
261 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
264 auto structType = typeConverter->convertType(op.getType());
268 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
272 Value rhsRe = arg.rhs.real();
273 Value rhsIm = arg.rhs.imag();
274 Value lhsRe = arg.lhs.real();
275 Value lhsIm = arg.lhs.imag();
278 loc, rewriter.
create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
279 rewriter.
create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
282 loc, rewriter.
create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
283 rewriter.
create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
285 result.setReal(rewriter, loc, real);
286 result.setImaginary(rewriter, loc, imag);
297 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
299 auto loc = op.getLoc();
300 BinaryComplexOperands arg =
301 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
304 auto structType = typeConverter->convertType(op.getType());
308 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
313 rewriter.
create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
315 rewriter.
create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
316 result.setReal(rewriter, loc, real);
317 result.setImaginary(rewriter, loc, imag);
343 struct ConvertComplexToLLVMPass
344 :
public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
347 void runOnOperation()
override;
351 void ConvertComplexToLLVMPass::runOnOperation() {
358 target.addIllegalDialect<complex::ComplexDialect>();
372 void loadDependentDialects(
MLIRContext *context)
const final {
373 context->loadDialect<LLVM::LLVMDialect>();
378 void populateConvertToLLVMConversionPatterns(
388 +[](
MLIRContext *ctx, complex::ComplexDialect *dialect) {
389 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
static ComplexStructBuilder undef(OpBuilder &builder, Location loc, Type type)
Build IR creating an undef value of the complex number type.
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to LLVM.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.