68 ConversionPatternRewriter &rewriter)
const override {
70 Value input = adaptor.getOperand();
72 auto convertedInput = maybeExt(input, rewriter);
73 auto computeType = convertedInput.getType();
76 if (isa<Float32Type>(computeType)) {
77 const arith::FastMathFlags flag = op.getFastmath();
78 const bool useApprox =
79 mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
80 sincosFunc = useApprox ?
"__nv_fast_sincosf" :
"__nv_sincosf";
81 }
else if (isa<Float64Type>(computeType)) {
82 sincosFunc =
"__nv_sincos";
84 return rewriter.notifyMatchFailure(op,
85 "unsupported operand type for sincos");
88 auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
95 assert(scope &&
"Expected op to be inside automatic allocation scope");
96 rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
97 auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
98 rewriter.getI32IntegerAttr(1));
100 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
102 LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
105 createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
108 auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
109 auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
111 rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
112 maybeTrunc(cosResult, inputType, rewriter)});
118 if (isa<Float16Type, BFloat16Type>(operand.
getType()))
119 return LLVM::FPExtOp::create(rewriter, operand.
getLoc(),
125 Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter)
const {
127 return LLVM::FPTruncOp::create(rewriter, operand.
getLoc(), type, operand);
131 void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
132 StringRef funcName, Value input, Value sinPtr,
133 Value cosPtr, Operation *op)
const {
134 auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
135 auto ptrType = sinPtr.
getType();
137 SmallVector<Type> operandTypes = {input.
getType(), ptrType, ptrType};
138 auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
140 auto funcAttr = StringAttr::get(op->
getContext(), funcName);
146 assert(parentFunc &&
"expected there to be a parent function");
147 OpBuilder
b(parentFunc);
150 funcOp = LLVM::LLVMFuncOp::create(
b, globalloc, funcName, funcType);
153 SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
154 LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
164 "__nv_fmaxf",
"__nv_fmax");
166 "__nv_fminf",
"__nv_fmin");
190 "__nv_copysignf",
"__nv_copysign");
192 "__nv_cos",
"__nv_fast_cosf");
200 "__nv_exp",
"__nv_fast_expf");
211 "__nv_finitef",
"__nv_isfinited");
217 "__nv_log",
"__nv_fast_logf");
219 "__nv_log10",
"__nv_fast_log10f");
223 "__nv_log2",
"__nv_fast_log2f");
225 "__nv_pow",
"__nv_fast_powf");
227 "__nv_powif",
"__nv_powi");
231 "__nv_rintf",
"__nv_rint");
235 "__nv_sin",
"__nv_fast_sinf");
241 "__nv_tan",
"__nv_fast_tanf");