20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/FormatVariadic.h"
23#define DEBUG_TYPE "math-to-spirv-pattern"
35 if (
auto vectorType = dyn_cast<VectorType>(type)) {
36 if (!vectorType.getElementType().isInteger(32))
39 return spirv::ConstantOp::create(builder, loc, type,
43 return spirv::ConstantOp::create(builder, loc, type,
56 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
59 if (vecTy.isScalable())
61 if (vecTy.getRank() > 1)
78 for (
Type ty : allTypes) {
80 return rewriter.notifyMatchFailure(
83 "unsupported source type for Math to SPIR-V conversion: {0}",
102template <
typename Op,
typename SPIRVOp>
103struct CheckedElementwiseOpPattern final
105 using BasePattern =
typename spirv::ElementwiseOpPattern<Op, SPIRVOp>;
106 using BasePattern::BasePattern;
109 matchAndRewrite(Op op,
typename Op::Adaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const override {
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
119struct CopySignPattern final :
public OpConversionPattern<math::CopySignOp> {
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter)
const override {
129 Type type = getTypeConverter()->convertType(copySignOp.getType());
134 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135 floatType = scalarType;
136 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137 floatType = cast<FloatType>(vectorType.getElementType());
142 Location loc = copySignOp.getLoc();
143 int bitwidth = floatType.getWidth();
144 Type intType = rewriter.getIntegerType(bitwidth);
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
147 Value signMask = spirv::ConstantOp::create(
148 rewriter, loc, intType, rewriter.getIntegerAttr(intType, intValue));
149 Value valueMask = spirv::ConstantOp::create(
150 rewriter, loc, intType,
151 rewriter.getIntegerAttr(intType, intValue - 1u));
153 if (
auto vectorType = dyn_cast<VectorType>(type)) {
154 assert(vectorType.getRank() == 1);
155 int count = vectorType.getNumElements();
156 intType = VectorType::get(count, intType);
158 SmallVector<Value> signSplat(count, signMask);
159 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
162 SmallVector<Value> valueSplat(count, valueMask);
163 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
168 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
172 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
174 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
177 Value
result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
179 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type,
result);
189struct CountLeadingZerosPattern final
190 :
public OpConversionPattern<math::CountLeadingZerosOp> {
194 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
195 ConversionPatternRewriter &rewriter)
const override {
199 Type type = getTypeConverter()->convertType(countOp.getType());
204 unsigned bitwidth = 0;
205 if (isa<IntegerType>(type))
207 if (
auto vectorType = dyn_cast<VectorType>(type))
208 bitwidth = vectorType.getElementTypeBitWidth();
212 Location loc = countOp.getLoc();
213 Value input = adaptor.getOperand();
218 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
223 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
228 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
229 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
230 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
240template <
typename ExpOp>
241struct ExpM1OpPattern final :
public OpConversionPattern<math::ExpM1Op> {
245 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 assert(adaptor.getOperands().size() == 1);
252 Location loc = operation.getLoc();
253 Type type = this->getTypeConverter()->convertType(operation.getType());
257 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
258 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
259 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
268template <
typename LogOp>
269struct Log1pOpPattern final :
public OpConversionPattern<math::Log1pOp> {
273 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter)
const override {
275 assert(adaptor.getOperands().size() == 1);
280 Location loc = operation.getLoc();
281 Type type = this->getTypeConverter()->convertType(operation.getType());
285 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
288 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
300template <
typename MathLogOp,
typename SpirvLogOp>
301struct Log2Log10OpPattern final :
public OpConversionPattern<MathLogOp> {
302 using OpConversionPattern<MathLogOp>::OpConversionPattern;
303 using typename OpConversionPattern<MathLogOp>::OpAdaptor;
305 static constexpr double log2Reciprocal =
306 1.442695040888963407359924681001892137426645954152985934135449407;
307 static constexpr double log10Reciprocal =
308 0.4342944819032518276511289189166050822943970058036665661144537832;
311 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
312 ConversionPatternRewriter &rewriter)
const override {
313 assert(adaptor.getOperands().size() == 1);
318 Location loc = operation.getLoc();
319 Type type = this->getTypeConverter()->convertType(operation.getType());
321 return rewriter.notifyMatchFailure(operation,
"type conversion failed");
323 auto getConstantValue = [&](
double value) {
324 if (
auto floatType = dyn_cast<FloatType>(type)) {
325 return spirv::ConstantOp::create(
326 rewriter, loc, type, rewriter.getFloatAttr(floatType, value));
328 if (
auto vectorType = dyn_cast<VectorType>(type)) {
329 Type elemType = vectorType.getElementType();
331 if (isa<FloatType>(elemType)) {
332 return spirv::ConstantOp::create(
335 vectorType, FloatAttr::get(elemType, value).getValue()));
339 llvm_unreachable(
"unimplemented types for log2/log10");
342 Value constantValue = getConstantValue(
343 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345 Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
346 rewriter.replaceOpWithNewOp<spirv::FMulOp>(operation, type, log,
353struct PowFOpPattern final :
public OpConversionPattern<math::PowFOp> {
357 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
358 ConversionPatternRewriter &rewriter)
const override {
362 Type dstType = getTypeConverter()->convertType(powfOp.getType());
367 FloatType scalarFloatType;
368 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
369 scalarFloatType = scalarType;
370 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
371 scalarFloatType = cast<FloatType>(vectorType.getElementType());
377 Type scalarIntType = rewriter.getIntegerType(32);
378 Type intType = scalarIntType;
379 auto operandType = adaptor.getRhs().getType();
380 if (
auto vectorType = dyn_cast<VectorType>(operandType)) {
381 auto shape = vectorType.getShape();
382 intType = VectorType::get(shape, scalarIntType);
387 Location loc = powfOp.getLoc();
388 Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
390 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
396 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398 spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
399 Value expRemNonZero =
400 spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
401 Value cmpNegativeWithFractionalExp =
402 spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
404 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
405 const auto nan = APFloat::getNaN(floatSemantics);
406 Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
407 if (
auto vectorType = dyn_cast<VectorType>(operandType))
411 spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
413 spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
414 nanValue, adaptor.getLhs());
415 Value
abs = spirv::GLFAbsOp::create(rewriter, loc,
lhs);
423 spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
424 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425 Value bitwiseAndOne =
426 spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
427 Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
430 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
431 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
434 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
435 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, shouldNegate, negate,
442struct RoundOpPattern final :
public OpConversionPattern<math::RoundOp> {
446 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
447 ConversionPatternRewriter &rewriter)
const override {
451 Location loc = roundOp.getLoc();
452 Value operand = roundOp.getOperand();
456 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
457 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
459 if (VectorType vty = dyn_cast<VectorType>(ty)) {
460 half = spirv::ConstantOp::create(
463 rewriter.getFloatAttr(ety, 0.5).getValue()));
465 half = spirv::ConstantOp::create(rewriter, loc, ty,
466 rewriter.getFloatAttr(ety, 0.5));
469 auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
470 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
471 auto sub = spirv::FSubOp::create(rewriter, loc, abs, floor);
473 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
474 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
475 auto add = spirv::FAddOp::create(rewriter, loc, floor, select);
476 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp,
add, operand);
492 .add<CopySignPattern,
493 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
494 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
495 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
496 typeConverter,
patterns.getContext());
500 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
501 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
502 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
503 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
504 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
505 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
506 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
507 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
508 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
509 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
510 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
511 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
512 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
513 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
514 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
515 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
516 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
517 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
518 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
519 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
520 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
521 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
522 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
523 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
524 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
525 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
526 typeConverter,
patterns.getContext());
529 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
530 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
531 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
532 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
533 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
534 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
535 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
536 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
537 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
538 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
539 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
540 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
541 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
542 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
543 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
544 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
545 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
546 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
547 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
548 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
549 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
550 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
551 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
552 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
553 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
554 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
555 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
556 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
557 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
558 typeConverter,
patterns.getContext());
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
IntegerAttr getI32IntegerAttr(int32_t value)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.