16 #define GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS
17 #include "mlir/Conversion/Passes.h.inc"
// Generic rewrite pattern: replaces a complex-dialect op `Op` with a call to
// the OCML library function whose name is passed to the constructor as
// `funcName` (e.g. "__ocml_cabs_f32"). One instantiation exists per
// (op, FloatTy) pair; FloatTy selects which element type this instance
// accepts.
24 template <
typename Op,
typename FloatTy>
28 ComplexOpToROCDLLibraryCalls(
MLIRContext *context, StringRef funcName,
// Determine the float type to check: for a complex-valued result use its
// element type, otherwise check the result type as-is.
34 Type resType = op.getType();
35 if (
auto complexType = dyn_cast<ComplexType>(resType))
36 resType = complexType.getElementType();
// Only the instantiation whose FloatTy matches handles this op.
// NOTE(review): the statement guarded by this check is not present in this
// excerpt — presumably it returns a match failure so the f32 or f64
// instantiation handles the op instead; confirm.
37 if (!isa<FloatTy>(resType))
// Obtain the callee as a SymbolOpInterface (null-tolerant cast).
// NOTE(review): the cast's operand is not present here — presumably a symbol
// lookup of funcName in the nearest symbol table; verify.
40 auto opFunc = dyn_cast_or_null<SymbolOpInterface>(
46 rewriter.
getContext(), op->getOperandTypes(), op->getResultTypes());
// Declare the library function at an unknown location with a signature built
// from the op's operand and result types (see the line above).
// NOTE(review): presumably only reached when the lookup found no existing
// declaration; the guarding condition is not visible here.
47 opFunc = func::FuncOp::create(rewriter, rewriter.
getUnknownLoc(),
// Lowers complex.pow with the identity  lhs^rhs = exp(rhs * log(lhs)).
// There is no dedicated library call for pow here. Instead it is decomposed
// into complex.log, complex.mul and complex.exp; the log/exp ops are then
// rewritten into OCML calls by the library-call patterns in this file.
// NOTE(review): no special case for a zero base is visible in this excerpt.
// If log(0) evaluates to -inf, then rhs * log(lhs) can produce NaN (e.g. for
// pow(0, 0)) instead of the conventional result — confirm against the OCML
// clog semantics.
64 LogicalResult matchAndRewrite(complex::PowOp op,
// log(lhs) -> rhs * log(lhs) -> exp(...), then replace the original pow.
67 Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
68 Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
69 Value exp = complex::ExpOp::create(rewriter, loc, mul);
70 rewriter.replaceOp(op, exp);
79 patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>(
80 patterns.getContext(),
"__ocml_cabs_f32");
81 patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
82 patterns.getContext(),
"__ocml_cabs_f64");
83 patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
84 patterns.getContext(),
"__ocml_ccos_f32");
85 patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
86 patterns.getContext(),
"__ocml_ccos_f64");
87 patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
88 patterns.getContext(),
"__ocml_cexp_f32");
89 patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
90 patterns.getContext(),
"__ocml_cexp_f64");
91 patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
92 patterns.getContext(),
"__ocml_clog_f32");
93 patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
94 patterns.getContext(),
"__ocml_clog_f64");
95 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
96 patterns.getContext(),
"__ocml_csin_f32");
97 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
98 patterns.getContext(),
"__ocml_csin_f64");
99 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
100 patterns.getContext(),
"__ocml_csqrt_f32");
101 patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
102 patterns.getContext(),
"__ocml_csqrt_f64");
103 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
104 patterns.getContext(),
"__ocml_ctan_f32");
105 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
106 patterns.getContext(),
"__ocml_ctan_f64");
107 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
108 patterns.getContext(),
"__ocml_ctanh_f32");
109 patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
110 patterns.getContext(),
"__ocml_ctanh_f64");
// Pass that lowers complex-dialect ops to OCML library calls. Its base class
// is the TableGen-generated impl::ConvertComplexToROCDLLibraryCallsBase,
// enabled by GEN_PASS_DEF_CONVERTCOMPLEXTOROCDLLIBRARYCALLS and
// Passes.h.inc at the top of the file.
114 struct ConvertComplexToROCDLLibraryCallsPass
115 :
public impl::ConvertComplexToROCDLLibraryCallsBase<
116 ConvertComplexToROCDLLibraryCallsPass> {
// Pass entry point; defined out of line (ConvertComplexToROCDLLibraryCallsPass::runOnOperation).
117 void runOnOperation()
override;
// Runs the Complex -> OCML library-call conversion on the pass's operation.
// NOTE(review): the pattern-set construction and the conversion call are not
// present in this excerpt; presumably it uses the populate function and
// applyPartialConversion — verify.
121 void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
// The library-call patterns create func::FuncOp declarations (and
// presumably call ops), so the func dialect must stay legal.
128 target.addLegalDialect<func::FuncDialect>();
// complex.pow is decomposed into log/mul/exp. There is no OCML pattern for
// complex.mul, so the emitted mul must be allowed to remain.
129 target.addLegalOp<complex::MulOp>();
// Every op with a library-call pattern, plus pow (lowered by decomposition),
// must be converted.
// NOTE(review): these ops are illegal regardless of element type, but the
// library-call patterns only exist for Float32Type/Float64Type. An op on
// another element type (e.g. complex<f16>) would presumably cause the
// conversion to fail — confirm this is intended, or make legality dynamic.
130 target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
131 complex::LogOp, complex::PowOp, complex::SinOp,
132 complex::SqrtOp, complex::TanOp, complex::TanhOp>();
static MLIRContext * getContext(OpFoldResult val)
MLIRContext * getContext() const
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from the given operation `from`.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
void populateComplexToROCDLLibraryCallsConversionPatterns(RewritePatternSet &patterns)
Populate the given list with patterns that convert from Complex to ROCDL calls.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...