22 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
23 #include "mlir/Conversion/Passes.h.inc"
39 Value val = builder.
create<LLVM::PoisonOp>(loc, type);
71 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
73 auto loc = op.getLoc();
76 Value real = complexStruct.real(rewriter, op.getLoc());
77 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
79 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
84 loc, rewriter.
create<LLVM::FMulOp>(loc, real, real, fmf),
85 rewriter.
create<LLVM::FMulOp>(loc, imag, imag, fmf), fmf);
96 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
99 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
100 op->getAttrs(), *getTypeConverter(), rewriter);
108 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
111 auto loc = complexOp.getLoc();
112 auto structType = typeConverter->convertType(complexOp.getType());
115 complexStruct.setReal(rewriter, loc, adaptor.getReal());
116 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
118 rewriter.
replaceOp(complexOp, {complexStruct});
127 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
131 Value real = complexStruct.real(rewriter, op.getLoc());
142 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
146 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
153 struct BinaryComplexOperands {
154 std::complex<Value> lhs;
155 std::complex<Value> rhs;
158 template <
typename OpTy>
159 BinaryComplexOperands
160 unpackBinaryComplexOperands(OpTy op,
typename OpTy::Adaptor adaptor,
162 auto loc = op.getLoc();
165 BinaryComplexOperands unpacked;
167 unpacked.lhs.real(lhs.real(rewriter, loc));
168 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
170 unpacked.rhs.real(rhs.real(rewriter, loc));
171 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
180 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
182 auto loc = op.getLoc();
183 BinaryComplexOperands arg =
184 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
187 auto structType = typeConverter->convertType(op.getType());
191 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
196 rewriter.
create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
198 rewriter.
create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
199 result.setReal(rewriter, loc, real);
200 result.setImaginary(rewriter, loc, imag);
209 complex::ComplexRangeFlags target)
211 complexRange(target) {}
216 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
218 auto loc = op.getLoc();
219 BinaryComplexOperands arg =
220 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
223 auto structType = typeConverter->convertType(op.getType());
227 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
231 Value rhsRe = arg.rhs.real();
232 Value rhsIm = arg.rhs.imag();
233 Value lhsRe = arg.lhs.real();
234 Value lhsIm = arg.lhs.imag();
236 Value resultRe, resultIm;
238 if (complexRange == complex::ComplexRangeFlags::basic ||
239 complexRange == complex::ComplexRangeFlags::none) {
241 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
242 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
244 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
247 result.setReal(rewriter, loc, resultRe);
248 result.setImaginary(rewriter, loc, resultIm);
255 complex::ComplexRangeFlags complexRange;
262 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
264 auto loc = op.getLoc();
265 BinaryComplexOperands arg =
266 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
269 auto structType = typeConverter->convertType(op.getType());
273 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
277 Value rhsRe = arg.rhs.real();
278 Value rhsIm = arg.rhs.imag();
279 Value lhsRe = arg.lhs.real();
280 Value lhsIm = arg.lhs.imag();
283 loc, rewriter.
create<LLVM::FMulOp>(loc, rhsRe, lhsRe, fmf),
284 rewriter.
create<LLVM::FMulOp>(loc, rhsIm, lhsIm, fmf), fmf);
287 loc, rewriter.
create<LLVM::FMulOp>(loc, lhsIm, rhsRe, fmf),
288 rewriter.
create<LLVM::FMulOp>(loc, lhsRe, rhsIm, fmf), fmf);
290 result.setReal(rewriter, loc, real);
291 result.setImaginary(rewriter, loc, imag);
302 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
304 auto loc = op.getLoc();
305 BinaryComplexOperands arg =
306 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
309 auto structType = typeConverter->convertType(op.getType());
313 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
318 rewriter.
create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real(), fmf);
320 rewriter.
create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag(), fmf);
321 result.setReal(rewriter, loc, real);
322 result.setImaginary(rewriter, loc, imag);
332 complex::ComplexRangeFlags complexRange) {
345 patterns.add<DivOpConversion>(converter, complexRange);
350 struct ConvertComplexToLLVMPass
351 :
public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
354 void runOnOperation()
override;
358 void ConvertComplexToLLVMPass::runOnOperation() {
365 target.addIllegalDialect<complex::ComplexDialect>();
379 void loadDependentDialects(
MLIRContext *context)
const final {
380 context->loadDialect<LLVM::LLVMDialect>();
385 void populateConvertToLLVMConversionPatterns(
395 +[](
MLIRContext *ctx, complex::ComplexDialect *dialect) {
396 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using Smith's method.
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Convert a complex division to the LLVM dialect using the algebraic method.
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.