21 #define GEN_PASS_DEF_CONVERTCOMPLEXTOLLVMPASS
22 #include "mlir/Conversion/Passes.h.inc"
38 Value val = LLVM::PoisonOp::create(builder, loc, type);
70 matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
72 auto loc = op.getLoc();
75 Value real = complexStruct.real(rewriter, op.getLoc());
76 Value imag = complexStruct.imaginary(rewriter, op.getLoc());
78 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
82 Value sqNorm = LLVM::FAddOp::create(
83 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, real, real, fmf),
84 LLVM::FMulOp::create(rewriter, loc, imag, imag, fmf), fmf);
95 matchAndRewrite(complex::ConstantOp op, OpAdaptor adaptor,
98 op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(),
99 op->getAttrs(), *getTypeConverter(), rewriter);
107 matchAndRewrite(complex::CreateOp complexOp, OpAdaptor adaptor,
110 auto loc = complexOp.getLoc();
111 auto structType = typeConverter->convertType(complexOp.getType());
114 complexStruct.setReal(rewriter, loc, adaptor.getReal());
115 complexStruct.setImaginary(rewriter, loc, adaptor.getImaginary());
117 rewriter.
replaceOp(complexOp, {complexStruct});
126 matchAndRewrite(complex::ReOp op, OpAdaptor adaptor,
130 Value real = complexStruct.real(rewriter, op.getLoc());
141 matchAndRewrite(complex::ImOp op, OpAdaptor adaptor,
145 Value imaginary = complexStruct.imaginary(rewriter, op.getLoc());
152 struct BinaryComplexOperands {
153 std::complex<Value> lhs;
154 std::complex<Value> rhs;
157 template <
typename OpTy>
158 BinaryComplexOperands
159 unpackBinaryComplexOperands(OpTy op,
typename OpTy::Adaptor adaptor,
161 auto loc = op.getLoc();
164 BinaryComplexOperands unpacked;
166 unpacked.lhs.real(lhs.real(rewriter, loc));
167 unpacked.lhs.imag(lhs.imaginary(rewriter, loc));
169 unpacked.rhs.real(rhs.real(rewriter, loc));
170 unpacked.rhs.imag(rhs.imaginary(rewriter, loc));
179 matchAndRewrite(complex::AddOp op, OpAdaptor adaptor,
181 auto loc = op.getLoc();
182 BinaryComplexOperands arg =
183 unpackBinaryComplexOperands<complex::AddOp>(op, adaptor, rewriter);
186 auto structType = typeConverter->convertType(op.getType());
190 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
194 Value real = LLVM::FAddOp::create(rewriter, loc, arg.lhs.real(),
195 arg.rhs.real(), fmf);
196 Value imag = LLVM::FAddOp::create(rewriter, loc, arg.lhs.imag(),
197 arg.rhs.imag(), fmf);
198 result.setReal(rewriter, loc, real);
199 result.setImaginary(rewriter, loc, imag);
208 complex::ComplexRangeFlags target)
210 complexRange(target) {}
215 matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
217 auto loc = op.getLoc();
218 BinaryComplexOperands arg =
219 unpackBinaryComplexOperands<complex::DivOp>(op, adaptor, rewriter);
222 auto structType = typeConverter->convertType(op.getType());
226 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
230 Value rhsRe = arg.rhs.real();
231 Value rhsIm = arg.rhs.imag();
232 Value lhsRe = arg.lhs.real();
233 Value lhsIm = arg.lhs.imag();
235 Value resultRe, resultIm;
237 if (complexRange == complex::ComplexRangeFlags::basic ||
238 complexRange == complex::ComplexRangeFlags::none) {
240 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
241 }
else if (complexRange == complex::ComplexRangeFlags::improved) {
243 rewriter, loc, lhsRe, lhsIm, rhsRe, rhsIm, fmf, &resultRe, &resultIm);
246 result.setReal(rewriter, loc, resultRe);
247 result.setImaginary(rewriter, loc, resultIm);
254 complex::ComplexRangeFlags complexRange;
261 matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
263 auto loc = op.getLoc();
264 BinaryComplexOperands arg =
265 unpackBinaryComplexOperands<complex::MulOp>(op, adaptor, rewriter);
268 auto structType = typeConverter->convertType(op.getType());
272 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
276 Value rhsRe = arg.rhs.real();
277 Value rhsIm = arg.rhs.imag();
278 Value lhsRe = arg.lhs.real();
279 Value lhsIm = arg.lhs.imag();
281 Value real = LLVM::FSubOp::create(
282 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, rhsRe, lhsRe, fmf),
283 LLVM::FMulOp::create(rewriter, loc, rhsIm, lhsIm, fmf), fmf);
285 Value imag = LLVM::FAddOp::create(
286 rewriter, loc, LLVM::FMulOp::create(rewriter, loc, lhsIm, rhsRe, fmf),
287 LLVM::FMulOp::create(rewriter, loc, lhsRe, rhsIm, fmf), fmf);
289 result.setReal(rewriter, loc, real);
290 result.setImaginary(rewriter, loc, imag);
301 matchAndRewrite(complex::SubOp op, OpAdaptor adaptor,
303 auto loc = op.getLoc();
304 BinaryComplexOperands arg =
305 unpackBinaryComplexOperands<complex::SubOp>(op, adaptor, rewriter);
308 auto structType = typeConverter->convertType(op.getType());
312 arith::FastMathFlagsAttr complexFMFAttr = op.getFastMathFlagsAttr();
316 Value real = LLVM::FSubOp::create(rewriter, loc, arg.lhs.real(),
317 arg.rhs.real(), fmf);
318 Value imag = LLVM::FSubOp::create(rewriter, loc, arg.lhs.imag(),
319 arg.rhs.imag(), fmf);
320 result.setReal(rewriter, loc, real);
321 result.setImaginary(rewriter, loc, imag);
331 complex::ComplexRangeFlags complexRange) {
344 patterns.add<DivOpConversion>(converter, complexRange);
349 struct ConvertComplexToLLVMPass
350 :
public impl::ConvertComplexToLLVMPassBase<ConvertComplexToLLVMPass> {
353 void runOnOperation()
override;
357 void ConvertComplexToLLVMPass::runOnOperation() {
364 target.addIllegalDialect<complex::ComplexDialect>();
378 void loadDependentDialects(
MLIRContext *context)
const final {
379 context->loadDialect<LLVM::LLVMDialect>();
384 void populateConvertToLLVMConversionPatterns(
394 +[](
MLIRContext *ctx, complex::ComplexDialect *dialect) {
395 dialect->addInterfaces<ComplexToLLVMDialectInterface>();
static constexpr unsigned kRealPosInComplexNumberStruct
static constexpr unsigned kImaginaryPosInComplexNumberStruct
static MLIRContext * getContext(OpFoldResult val)
Value imaginary(OpBuilder &builder, Location loc)
void setImaginary(OpBuilder &builder, Location loc, Value imaginary)
void setReal(OpBuilder &builder, Location loc, Value real)
Value real(OpBuilder &builder, Location loc)
static ComplexStructBuilder poison(OpBuilder &builder, Location loc, Type type)
Build IR creating a poison value of the complex number type.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF)
Maps arithmetic fastmath enum values to LLVM enum values.
void convertDivToLLVMUsingRangeReduction(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Converts a complex division to the LLVM dialect using Smith's method (range reduction).
void convertDivToLLVMUsingAlgebraic(ConversionPatternRewriter &rewriter, Location loc, Value lhsRe, Value lhsIm, Value rhsRe, Value rhsIm, LLVM::FastmathFlagsAttr fmf, Value *resultRe, Value *resultIm)
Converts a complex division to the LLVM dialect using the algebraic method.
Include the generated interface declarations.
void populateComplexToLLVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, mlir::complex::ComplexRangeFlags complexRange=mlir::complex::ComplexRangeFlags::basic)
Populate the given list with patterns that convert from Complex to LLVM.
const FrozenRewritePatternSet & patterns
void registerConvertComplexToLLVMInterface(DialectRegistry &registry)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.