20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/FormatVariadic.h"
23#define DEBUG_TYPE "math-to-spirv-pattern"
35 if (
auto vectorType = dyn_cast<VectorType>(type)) {
36 if (!vectorType.getElementType().isInteger(32))
39 return spirv::ConstantOp::create(builder, loc, type,
43 return spirv::ConstantOp::create(builder, loc, type,
56 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
59 if (vecTy.isScalable())
61 if (vecTy.getRank() > 1)
78 for (
Type ty : allTypes) {
80 return rewriter.notifyMatchFailure(
83 "unsupported source type for Math to SPIR-V conversion: {0}",
102template <
typename Op,
typename SPIRVOp>
103struct CheckedElementwiseOpPattern final
105 using BasePattern =
typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106 using BasePattern::BasePattern;
109 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const override {
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
119struct CopySignPattern final :
public OpConversionPattern<math::CopySignOp> {
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter)
const override {
129 Type type = getTypeConverter()->convertType(copySignOp.getType());
134 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135 floatType = scalarType;
136 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137 floatType = cast<FloatType>(vectorType.getElementType());
142 Location loc = copySignOp.getLoc();
143 int bitwidth = floatType.getWidth();
144 Type intType = rewriter.getIntegerType(bitwidth);
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
147 Value signMask = spirv::ConstantOp::create(
148 rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
149 Value valueMask = spirv::ConstantOp::create(
150 rewriter, loc, intType,
151 rewriter.getIntegerAttr(intType, intValue - 1u));
153 if (
auto vectorType = dyn_cast<VectorType>(type)) {
154 assert(vectorType.getRank() == 1);
155 int count = vectorType.getNumElements();
156 intType = VectorType::get(count, intType);
158 Repeated<Value> signSplat(count, signMask);
159 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
162 Repeated<Value> valueSplat(count, valueMask);
163 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
168 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
172 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
174 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
177 Value
result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
179 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type,
result);
192struct CountLeadingZerosPattern final
193 :
public OpConversionPattern<math::CountLeadingZerosOp> {
197 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
198 ConversionPatternRewriter &rewriter)
const override {
202 Type type = getTypeConverter()->convertType(countOp.getType());
206 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
207 if (!typeConverter.getTargetEnv().allows(spirv::Capability::Shader))
208 return rewriter.notifyMatchFailure(countOp,
"requires Shader capability");
211 unsigned bitwidth = 0;
212 if (isa<IntegerType>(type))
214 if (
auto vectorType = dyn_cast<VectorType>(type))
215 bitwidth = vectorType.getElementTypeBitWidth();
219 Location loc = countOp.getLoc();
220 Value input = adaptor.getOperand();
225 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
230 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
235 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
236 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
237 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
247template <
typename ExpOp>
248struct ExpM1OpPattern final :
public OpConversionPattern<math::ExpM1Op> {
252 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter)
const override {
254 assert(adaptor.getOperands().size() == 1);
259 Location loc = operation.getLoc();
260 Type type = this->getTypeConverter()->convertType(operation.getType());
264 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
265 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
266 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
275template <
typename LogOp>
276struct Log1pOpPattern final :
public OpConversionPattern<math::Log1pOp> {
280 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
281 ConversionPatternRewriter &rewriter)
const override {
282 assert(adaptor.getOperands().size() == 1);
287 Location loc = operation.getLoc();
288 Type type = this->getTypeConverter()->convertType(operation.getType());
292 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
294 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
295 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
304struct Log10OpPattern final :
public OpConversionPattern<math::Log10Op> {
307 static constexpr double log10Reciprocal =
308 0.4342944819032518276511289189166050822943970058036665661144537832;
311 matchAndRewrite(math::Log10Op operation, OpAdaptor adaptor,
312 ConversionPatternRewriter &rewriter)
const override {
313 assert(adaptor.getOperands().size() == 1);
318 Location loc = operation.getLoc();
319 Type type = this->getTypeConverter()->convertType(operation.getType());
321 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
323 auto getConstantValue = [&](
double value) {
324 if (
auto floatType = dyn_cast<FloatType>(type)) {
325 return spirv::ConstantOp::create(
326 rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
328 if (
auto vectorType = dyn_cast<VectorType>(type)) {
329 Type elemType = vectorType.getElementType();
331 if (isa<FloatType>(elemType)) {
332 return spirv::ConstantOp::create(
335 vectorType, FloatAttr::get(elemType, value).getValue()));
338 llvm_unreachable(
"unimplemented type for log10");
341 Value constantValue = getConstantValue(log10Reciprocal);
342 Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getOperand());
343 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
350struct PowFOpPattern final :
public OpConversionPattern<math::PowFOp> {
354 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
355 ConversionPatternRewriter &rewriter)
const override {
359 Type dstType = getTypeConverter()->convertType(powfOp.getType());
364 FloatType scalarFloatType;
365 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
366 scalarFloatType = scalarType;
367 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
368 scalarFloatType = cast<FloatType>(vectorType.getElementType());
374 Type scalarIntType = rewriter.getIntegerType(32);
375 Type intType = scalarIntType;
376 auto operandType = adaptor.getRhs().getType();
377 if (
auto vectorType = dyn_cast<VectorType>(operandType)) {
378 auto shape = vectorType.getShape();
379 intType = VectorType::get(shape, scalarIntType);
384 Location loc = powfOp.getLoc();
385 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
387 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
393 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
395 spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
396 Value expRemNonZero =
397 spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
398 Value cmpNegativeWithFractionalExp =
399 spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
401 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
402 const auto nan = APFloat::getNaN(floatSemantics);
403 Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
404 if (
auto vectorType = dyn_cast<VectorType>(operandType))
408 spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
410 spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
411 nanValue, adaptor.getLhs());
412 Value
abs = spirv::GLFAbsOp::create(rewriter, loc,
lhs);
420 spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
421 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
422 Value bitwiseAndOne =
423 spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
424 Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
427 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
428 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
431 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
432 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
439struct PowIOpPattern final :
public OpConversionPattern<math::FPowIOp> {
443 matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor,
444 ConversionPatternRewriter &rewriter)
const override {
448 Type dstType = getTypeConverter()->convertType(op.getType());
452 rewriter.replaceOpWithNewOp<spirv::CLPownOp>(op, dstType, adaptor.getLhs(),
459struct RoundOpPattern final :
public OpConversionPattern<math::RoundOp> {
463 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
464 ConversionPatternRewriter &rewriter)
const override {
468 Location loc = roundOp.getLoc();
469 auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
471 return rewriter.notifyMatchFailure(
473 llvm::formatv(
"failed to convert type {0} for SPIR-V",
479 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
480 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
482 if (VectorType vty = dyn_cast<VectorType>(ty)) {
483 half = spirv::ConstantOp::create(
486 rewriter.getFloatAttr(ety, 0.5).getValue()));
488 half = spirv::ConstantOp::create(rewriter, loc, ty,
489 rewriter.getFloatAttr(ety, 0.5));
492 auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
493 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
494 auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
496 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
497 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
498 auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
499 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp,
add,
500 adaptor.getOperand());
516 .
add<CopySignPattern,
517 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
518 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
519 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
524 CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>, Log10OpPattern,
525 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
526 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
527 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
528 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
529 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
530 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
531 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
532 CheckedElementwiseOpPattern<math::Exp2Op, spirv::GLExp2Op>,
533 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
534 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
535 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
536 CheckedElementwiseOpPattern<math::Log2Op, spirv::GLLog2Op>,
537 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
538 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
539 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
540 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
541 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
542 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
543 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
544 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
545 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
546 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
547 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
548 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
549 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
554 Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
555 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
556 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
557 CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>,
558 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
559 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
560 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
561 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
562 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
563 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
564 CheckedElementwiseOpPattern<math::Exp2Op, spirv::CLExp2Op>,
565 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
566 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
567 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
568 CheckedElementwiseOpPattern<math::Log2Op, spirv::CLLog2Op>,
569 CheckedElementwiseOpPattern<math::Log10Op, spirv::CLLog10Op>,
570 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>, PowIOpPattern,
571 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
572 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
573 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
574 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
575 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
576 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
577 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
578 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
579 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
580 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
581 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
582 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
583 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
584 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.