13 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
25 #define DEBUG_TYPE "math-to-spirv-pattern"
37 if (
auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
41 return builder.
create<spirv::ConstantOp>(loc, type,
45 return builder.
create<spirv::ConstantOp>(loc, type,
58 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
61 if (vecTy.isScalable())
63 if (vecTy.getRank() > 1)
80 for (
Type ty : allTypes) {
85 "unsupported source type for Math to SPIR-V conversion: {0}",
104 template <
typename Op,
typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
108 using BasePattern::BasePattern;
111 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
136 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
145 int bitwidth = floatType.
getWidth();
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
149 Value signMask = rewriter.
create<spirv::ConstantOp>(
151 Value valueMask = rewriter.
create<spirv::ConstantOp>(
154 if (
auto vectorType = dyn_cast<VectorType>(type)) {
155 assert(vectorType.getRank() == 1);
156 int count = vectorType.getNumElements();
161 rewriter.
create<spirv::CompositeConstructOp>(loc, intType, signSplat);
164 valueMask = rewriter.
create<spirv::CompositeConstructOp>(loc, intType,
169 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
171 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
173 Value value = rewriter.
create<spirv::BitwiseAndOp>(
174 loc, intType,
ValueRange{lhsCast, valueMask});
178 Value result = rewriter.
create<spirv::BitwiseOrOp>(loc, intType,
190 struct CountLeadingZerosPattern final
195 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
200 Type type = getTypeConverter()->convertType(countOp.getType());
205 unsigned bitwidth = 0;
206 if (isa<IntegerType>(type))
208 if (
auto vectorType = dyn_cast<VectorType>(type))
209 bitwidth = vectorType.getElementTypeBitWidth();
214 Value input = adaptor.getOperand();
219 Value msb = rewriter.
create<spirv::GLFindUMsbOp>(loc, input);
224 Value subMsb = rewriter.
create<spirv::ISubOp>(loc, val31, msb);
229 Value subInput = rewriter.
create<spirv::ISubOp>(loc, val32, input);
230 Value cmp = rewriter.
create<spirv::ULessThanEqualOp>(loc, input, val1);
241 template <
typename ExpOp>
246 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
248 assert(adaptor.getOperands().size() == 1);
254 Type type = this->getTypeConverter()->convertType(operation.getType());
258 Value exp = rewriter.
create<ExpOp>(loc, type, adaptor.getOperand());
259 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
269 template <
typename LogOp>
274 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
276 assert(adaptor.getOperands().size() == 1);
282 Type type = this->getTypeConverter()->convertType(operation.getType());
286 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
288 rewriter.
create<spirv::FAddOp>(loc, one, adaptor.getOperand());
301 template <
typename MathLogOp,
typename SpirvLogOp>
306 static constexpr
double log2Reciprocal =
307 1.442695040888963407359924681001892137426645954152985934135449407;
308 static constexpr
double log10Reciprocal =
309 0.4342944819032518276511289189166050822943970058036665661144537832;
312 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
314 assert(adaptor.getOperands().size() == 1);
320 Type type = this->getTypeConverter()->convertType(operation.getType());
324 auto getConstantValue = [&](
double value) {
325 if (
auto floatType = dyn_cast<FloatType>(type)) {
326 return rewriter.
create<spirv::ConstantOp>(
329 if (
auto vectorType = dyn_cast<VectorType>(type)) {
330 Type elemType = vectorType.getElementType();
332 if (isa<FloatType>(elemType)) {
333 return rewriter.
create<spirv::ConstantOp>(
340 llvm_unreachable(
"unimplemented types for log2/log10");
343 Value constantValue = getConstantValue(
344 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
346 Value log = rewriter.
create<SpirvLogOp>(loc, adaptor.getOperand());
358 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
363 Type dstType = getTypeConverter()->convertType(powfOp.getType());
369 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370 scalarFloatType = scalarType;
371 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372 scalarFloatType = cast<FloatType>(vectorType.getElementType());
379 Type intType = scalarIntType;
380 auto operandType = adaptor.getRhs().getType();
381 if (
auto vectorType = dyn_cast<VectorType>(operandType)) {
382 auto shape = vectorType.getShape();
391 rewriter.
create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
397 Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
399 rewriter.
create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400 Value expRemNonZero =
401 rewriter.
create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402 Value cmpNegativeWithFractionalExp =
403 rewriter.
create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
406 const auto nan = APFloat::getNaN(floatSemantics);
408 if (
auto vectorType = dyn_cast<VectorType>(operandType))
412 rewriter.
create<spirv::ConstantOp>(loc, operandType, nanAttr);
414 loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
423 rewriter.
create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
424 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
425 Value bitwiseAndOne =
426 rewriter.
create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
427 Value isOdd = rewriter.
create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
430 Value pow = rewriter.
create<spirv::GLPowOp>(loc,
abs, adaptor.getRhs());
431 Value negate = rewriter.
create<spirv::FNegateOp>(loc, pow);
434 rewriter.
create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
446 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
452 Value operand = roundOp.getOperand();
457 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
459 if (VectorType vty = dyn_cast<VectorType>(ty)) {
460 half = rewriter.
create<spirv::ConstantOp>(
465 half = rewriter.
create<spirv::ConstantOp>(
469 auto abs = rewriter.
create<spirv::GLFAbsOp>(loc, operand);
473 rewriter.
create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
474 auto select = rewriter.
create<spirv::SelectOp>(loc, greater, one, zero);
475 auto add = rewriter.
create<spirv::FAddOp>(loc,
floor, select);
495 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
496 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
497 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
498 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
499 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
500 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
501 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
502 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
503 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
504 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
505 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
506 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
507 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
508 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
509 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
510 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
511 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
512 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
513 typeConverter,
patterns.getContext());
516 patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
517 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
518 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
519 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
520 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
521 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
522 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
523 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
524 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
525 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
526 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
527 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
528 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
529 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
530 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
531 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
532 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
533 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
534 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
535 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
536 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
537 typeConverter,
patterns.getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
Attributes are known-constant values of operations.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replacing the root of the pattern).
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.