13 #include "../SPIRVCommon/Pattern.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/FormatVariadic.h"
23 #define DEBUG_TYPE "math-to-spirv-pattern"
35 if (
auto vectorType = dyn_cast<VectorType>(type)) {
36 if (!vectorType.getElementType().isInteger(32))
39 return spirv::ConstantOp::create(builder, loc, type,
43 return spirv::ConstantOp::create(builder, loc, type,
56 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
59 if (vecTy.isScalable())
61 if (vecTy.getRank() > 1)
78 for (
Type ty : allTypes) {
83 "unsupported source type for Math to SPIR-V conversion: {0}",
102 template <
typename Op,
typename SPIRVOp>
103 struct CheckedElementwiseOpPattern final
106 using BasePattern::BasePattern;
109 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
129 Type type = getTypeConverter()->convertType(copySignOp.getType());
134 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135 floatType = scalarType;
136 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137 floatType = cast<FloatType>(vectorType.getElementType());
143 int bitwidth = floatType.getWidth();
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
147 Value signMask = spirv::ConstantOp::create(
148 rewriter, loc, intType, rewriter.
getIntegerAttr(intType, intValue));
149 Value valueMask = spirv::ConstantOp::create(
150 rewriter, loc, intType,
153 if (
auto vectorType = dyn_cast<VectorType>(type)) {
154 assert(vectorType.getRank() == 1);
155 int count = vectorType.getNumElements();
159 signMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
163 valueMask = spirv::CompositeConstructOp::create(rewriter, loc, intType,
168 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getLhs());
170 spirv::BitcastOp::create(rewriter, loc, intType, adaptor.getRhs());
172 Value value = spirv::BitwiseAndOp::create(rewriter, loc, intType,
174 Value sign = spirv::BitwiseAndOp::create(rewriter, loc, intType,
177 Value result = spirv::BitwiseOrOp::create(rewriter, loc, intType,
189 struct CountLeadingZerosPattern final
194 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
199 Type type = getTypeConverter()->convertType(countOp.getType());
204 unsigned bitwidth = 0;
205 if (isa<IntegerType>(type))
207 if (
auto vectorType = dyn_cast<VectorType>(type))
208 bitwidth = vectorType.getElementTypeBitWidth();
213 Value input = adaptor.getOperand();
218 Value msb = spirv::GLFindUMsbOp::create(rewriter, loc, input);
223 Value subMsb = spirv::ISubOp::create(rewriter, loc, val31, msb);
228 Value subInput = spirv::ISubOp::create(rewriter, loc, val32, input);
229 Value cmp = spirv::ULessThanEqualOp::create(rewriter, loc, input, val1);
240 template <
typename ExpOp>
245 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
247 assert(adaptor.getOperands().size() == 1);
253 Type type = this->getTypeConverter()->convertType(operation.getType());
257 Value exp = ExpOp::create(rewriter, loc, type, adaptor.getOperand());
258 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
268 template <
typename LogOp>
273 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
275 assert(adaptor.getOperands().size() == 1);
281 Type type = this->getTypeConverter()->convertType(operation.getType());
285 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
287 spirv::FAddOp::create(rewriter, loc, one, adaptor.getOperand());
300 template <
typename MathLogOp,
typename SpirvLogOp>
305 static constexpr
double log2Reciprocal =
306 1.442695040888963407359924681001892137426645954152985934135449407;
307 static constexpr
double log10Reciprocal =
308 0.4342944819032518276511289189166050822943970058036665661144537832;
311 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
313 assert(adaptor.getOperands().size() == 1);
319 Type type = this->getTypeConverter()->convertType(operation.getType());
323 auto getConstantValue = [&](
double value) {
324 if (
auto floatType = dyn_cast<FloatType>(type)) {
325 return spirv::ConstantOp::create(
326 rewriter, loc, type, rewriter.
getFloatAttr(floatType, value));
328 if (
auto vectorType = dyn_cast<VectorType>(type)) {
329 Type elemType = vectorType.getElementType();
331 if (isa<FloatType>(elemType)) {
332 return spirv::ConstantOp::create(
339 llvm_unreachable(
"unimplemented types for log2/log10");
342 Value constantValue = getConstantValue(
343 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
345 Value log = SpirvLogOp::create(rewriter, loc, adaptor.getOperand());
357 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
362 Type dstType = getTypeConverter()->convertType(powfOp.getType());
367 FloatType scalarFloatType;
368 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
369 scalarFloatType = scalarType;
370 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
371 scalarFloatType = cast<FloatType>(vectorType.getElementType());
378 Type intType = scalarIntType;
379 auto operandType = adaptor.getRhs().getType();
380 if (
auto vectorType = dyn_cast<VectorType>(operandType)) {
381 auto shape = vectorType.getShape();
390 spirv::FOrdLessThanOp::create(rewriter, loc, adaptor.getLhs(), zero);
396 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398 spirv::FRemOp::create(rewriter, loc, adaptor.getRhs(), floatOne);
399 Value expRemNonZero =
400 spirv::FOrdNotEqualOp::create(rewriter, loc, expRem, zero);
401 Value cmpNegativeWithFractionalExp =
402 spirv::LogicalAndOp::create(rewriter, loc, expRemNonZero, lessThan);
404 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
405 const auto nan = APFloat::getNaN(floatSemantics);
407 if (
auto vectorType = dyn_cast<VectorType>(operandType))
411 spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
413 spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
414 NanValue, adaptor.getLhs());
415 Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
423 spirv::ConvertFToSOp::create(rewriter, loc, intType, adaptor.getRhs());
424 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425 Value bitwiseAndOne =
426 spirv::BitwiseAndOp::create(rewriter, loc, intRhs, intOne);
427 Value isOdd = spirv::IEqualOp::create(rewriter, loc, bitwiseAndOne, intOne);
430 Value pow = spirv::GLPowOp::create(rewriter, loc, abs, adaptor.getRhs());
431 Value negate = spirv::FNegateOp::create(rewriter, loc, pow);
434 spirv::LogicalAndOp::create(rewriter, loc, lessThan, isOdd);
446 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
452 Value operand = roundOp.getOperand();
457 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
459 if (VectorType vty = dyn_cast<VectorType>(ty)) {
460 half = spirv::ConstantOp::create(
465 half = spirv::ConstantOp::create(rewriter, loc, ty,
469 auto abs = spirv::GLFAbsOp::create(rewriter, loc, operand);
470 auto floor = spirv::GLFloorOp::create(rewriter, loc, abs);
471 auto sub = spirv::FSubOp::create(rewriter, loc, abs,
floor);
473 spirv::FOrdGreaterThanEqualOp::create(rewriter, loc, sub, half);
474 auto select = spirv::SelectOp::create(rewriter, loc, greater, one, zero);
475 auto add = spirv::FAddOp::create(rewriter, loc,
floor, select);
492 .add<CopySignPattern,
493 CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
494 CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
495 CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
496 typeConverter,
patterns.getContext());
500 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
501 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
502 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
503 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
504 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
505 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
506 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
507 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
508 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
509 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
510 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
511 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
512 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
513 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
514 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
515 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
516 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
517 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
518 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
519 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
520 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
521 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
522 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
523 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
524 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
525 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
526 typeConverter,
patterns.getContext());
529 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
530 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
531 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
532 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
533 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
534 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
535 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
536 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
537 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
538 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
539 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
540 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
541 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
542 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
543 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
544 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
545 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
546 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
547 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
548 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
549 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
550 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
551 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
552 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
553 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
554 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
555 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
556 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
557 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
558 typeConverter,
patterns.getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.