13 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
25 #define DEBUG_TYPE "math-to-spirv-pattern"
37 if (
auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
41 return builder.
create<spirv::ConstantOp>(loc, type,
45 return builder.
create<spirv::ConstantOp>(loc, type,
58 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
61 if (vecTy.isScalable())
63 if (vecTy.getRank() > 1)
80 for (
Type ty : allTypes) {
85 "unsupported source type for Math to SPIR-V conversion: {0}",
104 template <
typename Op,
typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
108 using BasePattern::BasePattern;
111 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
136 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
145 int bitwidth = floatType.
getWidth();
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
149 Value signMask = rewriter.
create<spirv::ConstantOp>(
151 Value valueMask = rewriter.
create<spirv::ConstantOp>(
154 if (
auto vectorType = dyn_cast<VectorType>(type)) {
155 assert(vectorType.getRank() == 1);
156 int count = vectorType.getNumElements();
161 rewriter.
create<spirv::CompositeConstructOp>(loc, intType, signSplat);
164 valueMask = rewriter.
create<spirv::CompositeConstructOp>(loc, intType,
169 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
171 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
173 Value value = rewriter.
create<spirv::BitwiseAndOp>(
174 loc, intType,
ValueRange{lhsCast, valueMask});
178 Value result = rewriter.
create<spirv::BitwiseOrOp>(loc, intType,
190 struct CountLeadingZerosPattern final
195 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
200 Type type = getTypeConverter()->convertType(countOp.getType());
205 unsigned bitwidth = 0;
206 if (isa<IntegerType>(type))
208 if (
auto vectorType = dyn_cast<VectorType>(type))
209 bitwidth = vectorType.getElementTypeBitWidth();
214 Value input = adaptor.getOperand();
219 Value msb = rewriter.
create<spirv::GLFindUMsbOp>(loc, input);
224 Value subMsb = rewriter.
create<spirv::ISubOp>(loc, val31, msb);
229 Value subInput = rewriter.
create<spirv::ISubOp>(loc, val32, input);
230 Value cmp = rewriter.
create<spirv::ULessThanEqualOp>(loc, input, val1);
241 template <
typename ExpOp>
246 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
248 assert(adaptor.getOperands().size() == 1);
254 Type type = this->getTypeConverter()->convertType(operation.getType());
258 Value exp = rewriter.
create<ExpOp>(loc, type, adaptor.getOperand());
259 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
269 template <
typename LogOp>
274 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
276 assert(adaptor.getOperands().size() == 1);
282 Type type = this->getTypeConverter()->convertType(operation.getType());
286 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
288 rewriter.
create<spirv::FAddOp>(loc, one, adaptor.getOperand());
301 template <
typename MathLogOp,
typename SpirvLogOp>
306 static constexpr
double log2Reciprocal =
307 1.442695040888963407359924681001892137426645954152985934135449407;
308 static constexpr
double log10Reciprocal =
309 0.4342944819032518276511289189166050822943970058036665661144537832;
312 matchAndRewrite(MathLogOp operation, OpAdaptor adaptor,
314 assert(adaptor.getOperands().size() == 1);
320 Type type = this->getTypeConverter()->convertType(operation.getType());
324 auto getConstantValue = [&](
double value) {
325 if (
auto floatType = dyn_cast<FloatType>(type)) {
326 return rewriter.
create<spirv::ConstantOp>(
329 if (
auto vectorType = dyn_cast<VectorType>(type)) {
330 Type elemType = vectorType.getElementType();
332 if (isa<FloatType>(elemType)) {
333 return rewriter.
create<spirv::ConstantOp>(
340 llvm_unreachable(
"unimplemented types for log2/log10");
343 Value constantValue = getConstantValue(
344 std::is_same<MathLogOp, math::Log2Op>() ? log2Reciprocal
346 Value log = rewriter.
create<SpirvLogOp>(loc, adaptor.getOperand());
358 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
363 Type dstType = getTypeConverter()->convertType(powfOp.getType());
369 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
370 scalarFloatType = scalarType;
371 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
372 scalarFloatType = cast<FloatType>(vectorType.getElementType());
379 Type intType = scalarIntType;
380 if (
auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
381 auto shape = vectorType.getShape();
391 rewriter.
create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392 Value abs = rewriter.
create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
400 rewriter.
create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
401 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
402 Value bitwiseAndOne =
403 rewriter.
create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
404 Value isOdd = rewriter.
create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
407 Value pow = rewriter.
create<spirv::GLPowOp>(loc,
abs, adaptor.getRhs());
408 Value negate = rewriter.
create<spirv::FNegateOp>(loc, pow);
411 rewriter.
create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
423 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
429 Value operand = roundOp.getOperand();
434 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
436 if (VectorType vty = dyn_cast<VectorType>(ty)) {
437 half = rewriter.
create<spirv::ConstantOp>(
442 half = rewriter.
create<spirv::ConstantOp>(
446 auto abs = rewriter.
create<spirv::GLFAbsOp>(loc, operand);
450 rewriter.
create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
451 auto select = rewriter.
create<spirv::SelectOp>(loc, greater, one, zero);
452 auto add = rewriter.
create<spirv::FAddOp>(loc,
floor, select);
468 patterns.
add<CopySignPattern>(typeConverter, patterns.
getContext());
472 .
add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
473 Log2Log10OpPattern<math::Log2Op, spirv::GLLogOp>,
474 Log2Log10OpPattern<math::Log10Op, spirv::GLLogOp>,
475 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
476 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
477 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
478 CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
479 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
480 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
481 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
482 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
483 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
484 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
485 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
486 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
487 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
488 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
489 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
493 patterns.
add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
494 Log2Log10OpPattern<math::Log2Op, spirv::CLLogOp>,
495 Log2Log10OpPattern<math::Log10Op, spirv::CLLogOp>,
496 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
497 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
498 CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
499 CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
500 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
501 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
502 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
503 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
504 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
505 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
506 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
507 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
508 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
509 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
510 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
511 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
512 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
513 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static DenseFPElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseFPElementsAttr with the given arguments.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
typename SourceOp::Adaptor OpAdaptor
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
DynamicAPInt floor(const Fraction &f)
Fraction abs(const Fraction &f)
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.