13 #include "../SPIRVCommon/Pattern.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/Support/FormatVariadic.h"
23 #define DEBUG_TYPE "math-to-spirv-pattern"
35 if (
auto vectorType = dyn_cast<VectorType>(type)) {
36 if (!vectorType.getElementType().isInteger(32))
39 return builder.
create<spirv::ConstantOp>(loc, type,
43 return builder.
create<spirv::ConstantOp>(loc, type,
56 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
57 if (!vecTy.getElementType().isIntOrIndexOrFloat())
59 if (vecTy.isScalable())
61 if (vecTy.getRank() > 1)
78 for (
Type ty : allTypes) {
83 "unsupported source type for Math to SPIR-V conversion: {0}",
102 template <
typename Op,
typename SPIRVOp>
103 struct CheckedElementwiseOpPattern final
106 using BasePattern::BasePattern;
109 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
114 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
123 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
129 Type type = getTypeConverter()->convertType(copySignOp.getType());
134 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
135 floatType = scalarType;
136 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
137 floatType = cast<FloatType>(vectorType.getElementType());
143 int bitwidth = floatType.getWidth();
145 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
147 Value signMask = rewriter.
create<spirv::ConstantOp>(
149 Value valueMask = rewriter.
create<spirv::ConstantOp>(
152 if (
auto vectorType = dyn_cast<VectorType>(type)) {
153 assert(vectorType.getRank() == 1);
154 int count = vectorType.getNumElements();
159 rewriter.
create<spirv::CompositeConstructOp>(loc, intType, signSplat);
162 valueMask = rewriter.
create<spirv::CompositeConstructOp>(loc, intType,
167 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
169 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
171 Value value = rewriter.
create<spirv::BitwiseAndOp>(
172 loc, intType,
ValueRange{lhsCast, valueMask});
176 Value result = rewriter.
create<spirv::BitwiseOrOp>(loc, intType,
188 struct CountLeadingZerosPattern final
193 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
198 Type type = getTypeConverter()->convertType(countOp.getType());
203 unsigned bitwidth = 0;
204 if (isa<IntegerType>(type))
206 if (
auto vectorType = dyn_cast<VectorType>(type))
207 bitwidth = vectorType.getElementTypeBitWidth();
212 Value input = adaptor.getOperand();
217 Value msb = rewriter.
create<spirv::GLFindUMsbOp>(loc, input);
222 Value subMsb = rewriter.
create<spirv::ISubOp>(loc, val31, msb);
227 Value subInput = rewriter.
create<spirv::ISubOp>(loc, val32, input);
228 Value cmp = rewriter.
create<spirv::ULessThanEqualOp>(loc, input, val1);
239 template <
typename ExpOp>
244 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
246 assert(adaptor.getOperands().size() == 1);
252 Type type = this->getTypeConverter()->convertType(operation.getType());
256 Value exp = rewriter.
create<ExpOp>(loc, type, adaptor.getOperand());
257 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
267 template <
typename LogOp>
272 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
274 assert(adaptor.getOperands().size() == 1);
280 Type type = this->getTypeConverter()->convertType(operation.getType());
284 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
286 rewriter.
create<spirv::FAddOp>(loc, one, adaptor.getOperand());
299 template <
typename MathLogOp,
typename SpirvLogOp>
304 static constexpr
double log2Reciprocal =
305 1.442695040888963407359924681001892137426645954152985934135449407;
306 static constexpr
double log10Reciprocal =
307 0.4342944819032518276511289189166050822943970058036665661144537832;
310 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
312 assert(adaptor.getOperands().size() == 1);
318 Type type = this->getTypeConverter()->convertType(operation.getType());
322 auto getConstantValue = [&](
double value) {
323 if (
auto floatType = dyn_cast<FloatType>(type)) {
324 return rewriter.
create<spirv::ConstantOp>(
327 if (
auto vectorType = dyn_cast<VectorType>(type)) {
328 Type elemType = vectorType.getElementType();
330 if (isa<FloatType>(elemType)) {
331 return rewriter.
create<spirv::ConstantOp>(
338 llvm_unreachable(
"unimplemented types for log2/log10");
341 Value constantValue = getConstantValue(
342 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
344 Value log = rewriter.
create<SpirvLogOp>(loc, adaptor.getOperand());
356 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
361 Type dstType = getTypeConverter()->convertType(powfOp.getType());
366 FloatType scalarFloatType;
367 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
368 scalarFloatType = scalarType;
369 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
370 scalarFloatType = cast<FloatType>(vectorType.getElementType());
377 Type intType = scalarIntType;
378 auto operandType = adaptor.getRhs().getType();
379 if (
auto vectorType = dyn_cast<VectorType>(operandType)) {
380 auto shape = vectorType.getShape();
389 rewriter.
create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
395 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
397 rewriter.
create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
398 Value expRemNonZero =
399 rewriter.
create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
400 Value cmpNegativeWithFractionalExp =
401 rewriter.
create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
403 const auto &floatSemantics = scalarFloatType.getFloatSemantics();
404 const auto nan = APFloat::getNaN(floatSemantics);
406 if (
auto vectorType = dyn_cast<VectorType>(operandType))
410 rewriter.
create<spirv::ConstantOp>(loc, operandType, nanAttr);
412 loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
421 rewriter.
create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
422 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
423 Value bitwiseAndOne =
424 rewriter.
create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
425 Value isOdd = rewriter.
create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
428 Value pow = rewriter.
create<spirv::GLPowOp>(loc,
abs, adaptor.getRhs());
429 Value negate = rewriter.
create<spirv::FNegateOp>(loc, pow);
432 rewriter.
create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
444 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
450 Value operand = roundOp.getOperand();
455 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
457 if (VectorType vty = dyn_cast<VectorType>(ty)) {
458 half = rewriter.
create<spirv::ConstantOp>(
463 half = rewriter.
create<spirv::ConstantOp>(
467 auto abs = rewriter.
create<spirv::GLFAbsOp>(loc, operand);
471 rewriter.
create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
472 auto select = rewriter.
create<spirv::SelectOp>(loc, greater, one, zero);
473 auto add = rewriter.
create<spirv::FAddOp>(loc,
floor, select);
493 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
494 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
495 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
496 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
497 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
498 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
499 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
500 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
501 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
502 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
503 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
504 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
505 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
506 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
507 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
508 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
509 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
510 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>,
511 CheckedElementwiseOpPattern<math::TanOp, spirv::GLTanOp>,
512 CheckedElementwiseOpPattern<math::AsinOp, spirv::GLAsinOp>,
513 CheckedElementwiseOpPattern<math::AcosOp, spirv::GLAcosOp>,
514 CheckedElementwiseOpPattern<math::SinhOp, spirv::GLSinhOp>,
515 CheckedElementwiseOpPattern<math::CoshOp, spirv::GLCoshOp>,
516 CheckedElementwiseOpPattern<math::AsinhOp, spirv::GLAsinhOp>,
517 CheckedElementwiseOpPattern<math::AcoshOp, spirv::GLAcoshOp>,
518 CheckedElementwiseOpPattern<math::AtanhOp, spirv::GLAtanhOp>>(
519 typeConverter,
patterns.getContext());
522 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
523 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
524 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
525 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
526 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
527 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
528 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
529 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
530 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
531 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
532 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
533 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
534 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
535 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
536 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
537 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
538 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
539 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
540 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
541 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
542 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>,
543 CheckedElementwiseOpPattern<math::TanOp, spirv::CLTanOp>,
544 CheckedElementwiseOpPattern<math::AsinOp, spirv::CLAsinOp>,
545 CheckedElementwiseOpPattern<math::AcosOp, spirv::CLAcosOp>,
546 CheckedElementwiseOpPattern<math::SinhOp, spirv::CLSinhOp>,
547 CheckedElementwiseOpPattern<math::CoshOp, spirv::CLCoshOp>,
548 CheckedElementwiseOpPattern<math::AsinhOp, spirv::CLAsinhOp>,
549 CheckedElementwiseOpPattern<math::AcoshOp, spirv::CLAcoshOp>,
550 CheckedElementwiseOpPattern<math::AtanhOp, spirv::CLAtanhOp>>(
551 typeConverter,
patterns.getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.