21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/FormatVariadic.h"
25#define DEBUG_TYPE "math-to-spirv-pattern"
// getScalarOrVectorI32Constant (fragment): per its declaration, creates a
// 32-bit scalar or vector integer spirv.Constant of `type`. The function
// header and several body lines are elided from this view.
// Vector path: only vectors whose element type is i32 are handled here; what
// happens for other element types is on an elided line.
37 if (
auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
// Presumably the vector-typed constant (its attribute argument is elided).
41 return spirv::ConstantOp::create(builder, loc, type,
// Presumably the scalar i32 constant (its attribute argument is elided).
45 return spirv::ConstantOp::create(builder, loc, type,
// isSupportedSourceType (fragment): decides whether `originalType` can be
// handled by the Math-to-SPIR-V patterns. The header and the statements that
// follow each check are elided from this view.
// A vector type is checked for: an int/index/float element type, a
// non-scalable shape, and rank <= 1. Each failing check presumably rejects
// the type (the corresponding returns are elided).
58 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
61 if (vecTy.isScalable())
63 if (vecTy.getRank() > 1)
// checkSourceOpTypes (fragment): walks `allTypes` and reports a match failure
// whose message names the offending type.
// NOTE(review): how `allTypes` is built and the per-type condition are on
// elided lines; presumably these are the source op's operand/result types
// tested with isSupportedSourceType. Confirm against the full source.
80 for (
Type ty : allTypes) {
82 return rewriter.notifyMatchFailure(
85 "unsupported source type for Math to SPIR-V conversion: {0}",
// Wraps spirv::ElementwiseOpPattern<Op, SPIRVOp>: it inherits the base
// pattern's constructors and, in matchAndRewrite, delegates the 1:1
// Op -> SPIRVOp rewrite to the base pattern.
// NOTE(review): the lines between the signature and the delegation are
// elided. Presumably they reject unsupported source types via
// checkSourceOpTypes; confirm.
104template <
typename Op,
typename SPIRVOp>
105struct CheckedElementwiseOpPattern final
107 using BasePattern =
typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
108 using BasePattern::BasePattern;
111 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
112 ConversionPatternRewriter &rewriter)
const override {
// Delegate the actual elementwise conversion to the wrapped base pattern.
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
// Lowers math.copysign with integer bit manipulation:
//   result = bitcast((bitcast(lhs) & valueMask) | (bitcast(rhs) & signMask))
// where signMask has only the top (sign) bit set and valueMask has every bit
// below it set. Some lines (the operands of the masking ops, failure paths)
// are elided from this view.
121struct CopySignPattern final :
public OpConversionPattern<math::CopySignOp> {
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter)
const override {
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
// The float element type (scalar or vector element) fixes the integer width.
136 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
144 Location loc = copySignOp.getLoc();
145 int bitwidth = floatType.getWidth();
146 Type intType = rewriter.getIntegerType(bitwidth);
// Sign bit of an IEEE value of this width, e.g. 0x80000000 for f32.
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
149 Value signMask = spirv::ConstantOp::create(
150 rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
// intValue - 1 sets every bit below the sign bit (magnitude bits).
151 Value valueMask = spirv::ConstantOp::create(
152 rewriter, loc, intType,
153 rewriter.getIntegerAttr(intType, intValue - 1u));
// Vector case: only 1-D vectors are expected (asserted). The scalar masks
// are replicated into vector masks with spirv.CompositeConstruct.
155 if (
auto vectorType = dyn_cast<VectorType>(type)) {
156 assert(vectorType.getRank() == 1);
157 int count = vectorType.getNumElements();
158 intType = VectorType::get(count, intType);
160 Repeated<Value> signSplat(count, signMask);
161 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
164 Repeated<Value> valueSplat(count, valueMask);
165 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
// Reinterpret both float operands as integers of the same width.
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
172 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
// Presumably lhs magnitude bits and rhs sign bit (operands elided).
174 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
176 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
179 Value
result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
// Reinterpret the combined bits back as the converted float type.
181 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type,
result);
// Lowers math.ctlz using GLSL FindUMsb. The pattern requires the Shader
// capability. From the visible lines:
//   result = (x <=u 1) ? 32 - x : 31 - msb(x)
// so 0 -> 32 and 1 -> 31, and otherwise the count is 31 minus the index of
// the highest set bit.
// NOTE(review): the val31/val32/val1 constants, the bitwidth check and the
// select's false operand are on elided lines; presumably the pattern is
// restricted to 32-bit integers.
194struct CountLeadingZerosPattern final
195 :
public OpConversionPattern<math::CountLeadingZerosOp> {
199 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
200 ConversionPatternRewriter &rewriter)
const override {
204 Type type = getTypeConverter()->convertType(countOp.getType());
// GLFindUMsb is a GLSL extended instruction, hence the Shader requirement.
208 auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
209 if (!typeConverter.getTargetEnv().allows(spirv::Capability::Shader))
210 return rewriter.notifyMatchFailure(countOp,
"requires Shader capability");
// Element bit width of a scalar or vector integer result.
213 unsigned bitwidth = 0;
214 if (isa<IntegerType>(type))
216 if (
auto vectorType = dyn_cast<VectorType>(type))
217 bitwidth = vectorType.getElementTypeBitWidth();
221 Location loc = countOp.getLoc();
222 Value input = adaptor.getOperand();
// Index of the most significant set bit.
227 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
232 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
// For inputs 0 and 1, 32 - input gives 32 and 31 respectively.
237 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
238 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
239 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
// Lowers math.cttz using GLSL FindILsb:
//   result = (x == 0) ? bitwidth : lsb(x)
// The zero input is special-cased to return the element bit width.
// NOTE(review): the val0/valBitwidth constants and the select's false
// operand (presumably `lsb`) are on elided lines.
248struct CountTrailingZerosPattern final
249 :
public OpConversionPattern<math::CountTrailingZerosOp> {
253 matchAndRewrite(math::CountTrailingZerosOp countOp, OpAdaptor adaptor,
254 ConversionPatternRewriter &rewriter)
const override {
258 Type type = getTypeConverter()->convertType(countOp.getType());
// Element bit width of a scalar or vector integer result.
262 unsigned bitwidth = 0;
263 if (isa<IntegerType>(type))
265 else if (
auto vectorType = dyn_cast<VectorType>(type))
266 bitwidth = vectorType.getElementTypeBitWidth();
270 Location loc = countOp.getLoc();
271 Value input = adaptor.getOperand();
// Index of the least significant set bit.
276 Value lsb = spirv::GLFindILsbOp::create(rewriter, loc, input);
277 Value isZero = spirv::IEqualOp::create(rewriter, loc, input, val0);
278 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, isZero, valBitwidth,
// Lowers math.expm1 as exp(x) - 1.
// ExpOp selects the exp instruction for the target instruction set
// (registered with spirv::GLExpOp and spirv::CLExpOp in this file).
// Note this is the naive formula, not a precision-preserving expm1.
288template <
typename ExpOp>
289struct ExpM1OpPattern final :
public OpConversionPattern<math::ExpM1Op> {
293 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
294 ConversionPatternRewriter &rewriter)
const override {
// expm1 is unary.
295 assert(adaptor.getOperands().size() == 1);
300 Location loc = operation.getLoc();
301 Type type = this->getTypeConverter()->convertType(operation.getType());
305 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
306 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
307 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
// Lowers math.log1p as log(1 + x).
// LogOp selects the log instruction for the target instruction set
// (registered with spirv::GLLogOp and spirv::CLLogOp in this file).
// Note this is the naive formula, not a precision-preserving log1p.
316template <
typename LogOp>
317struct Log1pOpPattern final :
public OpConversionPattern<math::Log1pOp> {
321 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
// log1p is unary.
323 assert(adaptor.getOperands().size() == 1);
328 Location loc = operation.getLoc();
329 Type type = this->getTypeConverter()->convertType(operation.getType());
333 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
// onePlus = 1 + x (its declaration line is elided).
335 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
336 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
// Lowers math.log10 for the GLSL path as ln(x) * (1 / ln(10)), using
// spirv::GLLogOp. The OpenCL path maps math.log10 directly to CLLog10Op.
345struct Log10OpPattern final :
public OpConversionPattern<math::Log10Op> {
// 1 / ln(10) (== log10(e)). Despite the name, this is the reciprocal of
// ln(10), not of log10.
348 static constexpr double log10Reciprocal =
349 0.4342944819032518276511289189166050822943970058036665661144537832;
352 matchAndRewrite(math::Log10Op operation, OpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
// log10 is unary.
354 assert(adaptor.getOperands().size() == 1);
359 Location loc = operation.getLoc();
360 Type type = this->getTypeConverter()->convertType(operation.getType());
362 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
// Builds a float constant of `type`: a scalar FloatAttr for float types,
// or (presumably) a splat dense attribute for float vectors. Any other
// type is treated as unreachable.
364 auto getConstantValue = [&](
double value) {
365 if (
auto floatType = dyn_cast<FloatType>(type)) {
366 return spirv::ConstantOp::create(
367 rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
369 if (
auto vectorType = dyn_cast<VectorType>(type)) {
370 Type elemType = vectorType.getElementType();
372 if (isa<FloatType>(elemType)) {
373 return spirv::ConstantOp::create(
376 vectorType, FloatAttr::get(elemType, value).getValue()));
379 llvm_unreachable(
"unimplemented type for log10");
// log10(x) = ln(x) * (1 / ln(10)).
382 Value constantValue = getConstantValue(log10Reciprocal);
383 Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getOperand());
384 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
// Lowers math.powf for the GLSL path. Three cases, by the exponent (rhs):
//  1. rhs is not a recognized constant with integer-valued elements
//     (oddMask stays empty): pow(x, y) = exp(y * log(x)).
//  2. Otherwise compute pow(|x|, y). If no exponent element is odd, that is
//     the result.
//  3. Otherwise negate the result where x < 0 and the exponent is odd. When
//     every element is odd the comparison is used directly; otherwise it is
//     ANDed with a per-lane constant odd mask.
391struct PowFOpPattern final :
public OpConversionPattern<math::PowFOp> {
395 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter)
const override {
400 Type dstType = getTypeConverter()->convertType(powfOp.getType());
404 Location loc = powfOp.getLoc();
405 Type operandType = adaptor.getRhs().getType();
// isOdd: truncates an integer-valued APFloat to an integer (rounding toward
// zero); the parity test itself is on elided lines.
410 auto isOdd = [](
const APFloat &v) {
413 v.convertToInteger(i, APFloat::rmTowardZero, &ignored);
// One parity bit per exponent element; left empty when the exponent is
// not a constant made entirely of integer values.
417 SmallVector<bool> oddMask;
// Scalar constant exponent.
421 .Case([&](FloatAttr a) {
422 if (a.getValue().isInteger())
423 oddMask.push_back(isOdd(a.getValue()));
// Splat vector exponent: a single parity entry covers all lanes.
425 .Case([&](SplatElementsAttr a) {
427 if (splat.isInteger())
428 oddMask.push_back(isOdd(splat));
// Non-splat dense exponent: the mask is committed only after the loop.
// Presumably a non-integer element aborts early (elided line), leaving
// oddMask empty.
430 .Case([&](DenseElementsAttr a) {
431 SmallVector<bool> mask;
432 for (
const APFloat &elt : a.
getValues<APFloat>()) {
433 if (!elt.isInteger())
435 mask.push_back(isOdd(elt));
437 oddMask = std::move(mask);
// Case 1: generic exponent.
// NOTE(review): GLSL.std.450 Log is undefined for x <= 0, so a negative or
// zero base yields an undefined result on this path; confirm this is the
// intended contract.
441 if (oddMask.empty()) {
442 Value log = spirv::GLLogOp::create(rewriter, loc, adaptor.getLhs());
443 Value
mul = spirv::FMulOp::create(rewriter, loc, adaptor.getRhs(), log);
444 rewriter.replaceOpWithNewOp<spirv::GLExpOp>(powfOp,
mul);
// Case 2: integer-valued constant exponent. GLSL Pow is evaluated on |x|;
// the sign is fixed up below.
450 Value
abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getLhs());
451 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
// All exponents even: the result is non-negative, so |x|^y is exact.
454 if (llvm::none_of(oddMask, [](
bool b) {
return b; })) {
455 rewriter.replaceOp(powfOp, pow);
// Case 3: negate lanes where x < 0 and the exponent is odd.
459 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
461 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
462 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
// At least one entry is true here, so all_equal means every exponent is
// odd and the sign test alone decides the negation.
465 if (llvm::all_equal(oddMask)) {
467 shouldNegate = lessThan;
// Mixed parities (vector only): build an i1 vector constant from oddMask
// and AND it with the per-lane sign test.
471 auto vecType = cast<VectorType>(operandType);
472 auto maskType = VectorType::get(vecType.getShape(), rewriter.getI1Type());
473 Value oddConst = spirv::ConstantOp::create(
476 spirv::LogicalAndOp::create(rewriter, loc, lessThan, oddConst);
479 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
// Lowers math.fpowi (float base, integer exponent) for the OpenCL path
// directly to the OpenCL `pown` extended instruction (spirv::CLPownOp).
486struct PowIOpPattern final :
public OpConversionPattern<math::FPowIOp> {
490 matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter)
const override {
495 Type dstType = getTypeConverter()->convertType(op.getType());
499 rewriter.replaceOpWithNewOp<spirv::CLPownOp>(op, dstType, adaptor.getLhs(),
// Lowers math.fpowi (float base, integer exponent) for the GLSL path:
//   r = pow(|base|, float(power));  result = (base < 0 && power odd) ? -r : r
// GLSL Pow is evaluated on |base| and the sign is restored for odd powers.
509struct PowIOpGLPattern final :
public OpConversionPattern<math::FPowIOp> {
513 matchAndRewrite(math::FPowIOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
518 Type dstType = getTypeConverter()->convertType(op.getType());
522 Location loc = op.getLoc();
523 Value base = adaptor.getLhs();
524 Value power = adaptor.getRhs();
// expFloat = signed int -> float conversion of the exponent (the
// declaration line is elided).
527 spirv::ConvertSToFOp::create(rewriter, loc, dstType, power);
528 Value
abs = spirv::GLFAbsOp::create(rewriter, loc, base);
529 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, expFloat);
531 Value zeroF = spirv::ConstantOp::getZero(dstType, loc, rewriter);
532 Value lessThan = spirv::FOrdLessThanOp::create(rewriter, loc, base, zeroF);
// Odd test on the integer exponent: (power & 1) == 1. This also holds for
// negative odd values in two's complement.
534 Type powerType = power.
getType();
535 Value oneI = spirv::ConstantOp::getOne(powerType, loc, rewriter);
536 Value lowBit = spirv::BitwiseAndOp::create(rewriter, loc, power, oneI);
537 Value isOdd = spirv::IEqualOp::create(rewriter, loc, lowBit, oneI);
540 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
542 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
542 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, shouldNegate, negate, pow);
// Lowers math.round (round half away from zero) for the GLSL path:
//   a = |x|;  f = floor(a);  r = f + ((a - f) >= 0.5 ? 1 : 0)
//   result = copysign(r, x)
// The final step emits math.copysign, which this file lowers separately with
// CopySignPattern.
548struct RoundOpPattern final :
public OpConversionPattern<math::RoundOp> {
552 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter)
const override {
557 Location loc = roundOp.getLoc();
558 auto ty = getTypeConverter()->convertType(adaptor.getOperand().getType());
560 return rewriter.notifyMatchFailure(
562 llvm::formatv(
"failed to convert type {0} for SPIR-V",
568 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
569 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
// The 0.5 threshold: a vector constant for vector types, a scalar
// otherwise. `ety` is declared on an elided line (presumably the element
// type of `ty`).
571 if (VectorType vty = dyn_cast<VectorType>(ty)) {
572 half = spirv::ConstantOp::create(
575 rewriter.getFloatAttr(ety, 0.5).getValue()));
577 half = spirv::ConstantOp::create(rewriter, loc, ty,
578 rewriter.getFloatAttr(ety, 0.5));
// Round the magnitude: the fractional part decides whether to add one.
581 auto abs = spirv::GLFAbsOp::create(rewriter, loc, adaptor.getOperand());
582 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
583 auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
585 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
586 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
587 auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
// Restore the original sign (also preserves the sign of zero).
588 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp,
add,
589 adaptor.getOperand());
// populateMathToSPIRVPatterns (fragment): registers the Math -> SPIR-V
// patterns. The function header, the constructor arguments passed to each
// add<>, and the start of the OpenCL group are elided from this view.
// Group 1: patterns emitting core SPIR-V ops, not tied to GLSL or OpenCL
// extended instructions.
605 .
add<CopySignPattern,
606 CheckedElementwiseOpPattern<math::CtPopOp, spirv::BitCountOp>,
607 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
608 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
609 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>,
610 CheckedElementwiseOpPattern<math::IsNormalOp, spirv::IsNormalOp>>(
// Group 2: patterns emitting GLSL.std.450 extended instructions (spirv::GL*).
615 .
add<CountLeadingZerosPattern, CountTrailingZerosPattern,
616 Log1pOpPattern<spirv::GLLogOp>, Log10OpPattern,
617 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, PowIOpGLPattern,
619 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
620 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
621 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
622 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
623 CheckedElementwiseOpPattern<math::ClampFOp, spirv::GLFClampOp>,
624 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
625 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
626 CheckedElementwiseOpPattern<math::Exp2Op, spirv::GLExp2Op>,
627 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
628 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
629 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
630 CheckedElementwiseOpPattern<math::Log2Op, spirv::GLLog2Op>,
631 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
632 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
633 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
634 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
635 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
636 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
637 CheckedElementwiseOpPattern<math::TruncOp, spirv::GLTruncOp>,
638 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
639 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
640 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
641 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
642 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
643 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
644 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
// Group 3 (continued; its opening lines are elided): patterns emitting
// OpenCL.std extended instructions (spirv::CL*). Here ctlz, powf and round
// map directly to CL instructions instead of the GLSL expansions above.
649 Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
650 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
651 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
652 CheckedElementwiseOpPattern<math::CountLeadingZerosOp, spirv::CLClzOp>,
653 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
654 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
655 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
656 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
657 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
658 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
659 CheckedElementwiseOpPattern<math::Exp2Op, spirv::CLExp2Op>,
660 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
661 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
662 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
663 CheckedElementwiseOpPattern<math::Log2Op, spirv::CLLog2Op>,
664 CheckedElementwiseOpPattern<math::Log10Op, spirv::CLLog10Op>,
665 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>, PowIOpPattern,
666 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
667 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
668 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
669 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
670 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
671 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
672 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
673 CheckedElementwiseOpPattern<math::TruncOp, spirv::CLTruncOp>,
674 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
675 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
676 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
677 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
678 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
679 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
680 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
auto getValues() const
Return the held element values as a range of the given type.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
llvm::TypeSwitch< T, ResultT > TypeSwitch
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.