18 #define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
19 #include "mlir/Conversion/Passes.h.inc"
26 template <
typename Op,
typename FloatTy>
30 ComplexOpToROCDLLibraryCalls(
MLIRContext *context, StringRef funcName,
36 Type resType = op.getType();
37 if (
auto complexType = dyn_cast<ComplexType>(resType))
38 resType = complexType.getElementType();
39 if (!isa<FloatTy>(resType))
42 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
48 rewriter.
getContext(), op->getOperandTypes(), op->getResultTypes());
49 opFunc = func::FuncOp::create(rewriter, rewriter.
getUnknownLoc(),
66 LogicalResult matchAndRewrite(complex::PowOp op,
69 auto fastmath = op.getFastmathAttr();
71 complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath);
73 complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath);
74 Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath);
75 rewriter.replaceOp(op, exp);
81 struct PowiOpToROCDLLibraryCalls :
public OpRewritePattern<complex::PowiOp> {
84 LogicalResult matchAndRewrite(complex::PowiOp op,
87 Type elementType = complexType.getElementType();
89 Type exponentType = op.getRhs().getType();
90 Type exponentFloatType = elementType;
91 if (
auto shapedType = dyn_cast<ShapedType>(exponentType))
92 exponentFloatType = shapedType.cloneWith(std::nullopt, elementType);
96 rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
97 Value zeroImag = rewriter.create<arith::ConstantOp>(
98 loc, rewriter.getZeroAttr(exponentFloatType));
99 Value exponent = rewriter.create<complex::CreateOp>(
100 loc, op.getLhs().getType(), exponentReal, zeroImag);
102 rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
103 exponent, op.getFastmathAttr());
113 patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
114 patterns.getContext(),
"__ocml_cabs_f32");
115 patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
116 patterns.getContext(),
"__ocml_cabs_f64");
117 patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
118 patterns.getContext(),
"__ocml_ccos_f32");
119 patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
120 patterns.getContext(),
"__ocml_ccos_f64");
121 patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
122 patterns.getContext(),
"__ocml_cexp_f32");
123 patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
124 patterns.getContext(),
"__ocml_cexp_f64");
125 patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
126 patterns.getContext(),
"__ocml_clog_f32");
127 patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
128 patterns.getContext(),
"__ocml_clog_f64");
129 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
130 patterns.getContext(),
"__ocml_csin_f32");
131 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
132 patterns.getContext(),
"__ocml_csin_f64");
133 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
134 patterns.getContext(),
"__ocml_csqrt_f32");
135 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
136 patterns.getContext(),
"__ocml_csqrt_f64");
137 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
138 patterns.getContext(),
"__ocml_ctan_f32");
139 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
140 patterns.getContext(),
"__ocml_ctan_f64");
141 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
142 patterns.getContext(),
"__ocml_ctanh_f32");
143 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
144 patterns.getContext(),
"__ocml_ctanh_f64");
148 struct ConvertComplexToROCDLLibraryCallsPass
149 :
public impl::ConvertComplexToROCDLLibraryCallsBase<
150 ConvertComplexToROCDLLibraryCallsPass> {
151 void runOnOperation()
override;
155 void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
162 target.addLegalDialect<arith::ArithDialect, func::FuncDialect>();
163 target.addLegalOp<complex::CreateOp, complex::MulOp>();
164 target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
165 complex::LogOp, complex::PowOp, complex::PowiOp,
166 complex::SinOp, complex::SqrtOp, complex::TanOp,
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
void populateComplexToROCDLLibraryCallsConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to ROCDL calls.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.