13 #include "../SPIRVCommon/Pattern.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
25 #define DEBUG_TYPE "math-to-spirv-pattern"
37 if (
auto vectorType = dyn_cast<VectorType>(type)) {
38 if (!vectorType.getElementType().isInteger(32))
41 return builder.
create<spirv::ConstantOp>(loc, type,
45 return builder.
create<spirv::ConstantOp>(loc, type,
58 if (
auto vecTy = dyn_cast<VectorType>(originalType)) {
59 if (!vecTy.getElementType().isIntOrIndexOrFloat())
61 if (vecTy.isScalable())
63 if (vecTy.getRank() > 1)
80 for (
Type ty : allTypes) {
85 "unsupported source type for Math to SPIR-V conversion: {0}",
104 template <
typename Op,
typename SPIRVOp>
105 struct CheckedElementwiseOpPattern final
108 using BasePattern::BasePattern;
111 matchAndRewrite(
Op op,
typename Op::Adaptor adaptor,
116 return BasePattern::matchAndRewrite(op, adaptor, rewriter);
125 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
131 Type type = getTypeConverter()->convertType(copySignOp.getType());
136 if (
auto scalarType = dyn_cast<FloatType>(copySignOp.getType())) {
137 floatType = scalarType;
138 }
else if (
auto vectorType = dyn_cast<VectorType>(copySignOp.getType())) {
139 floatType = cast<FloatType>(vectorType.getElementType());
145 int bitwidth = floatType.
getWidth();
147 uint64_t intValue = uint64_t(1) << (bitwidth - 1);
149 Value signMask = rewriter.
create<spirv::ConstantOp>(
151 Value valueMask = rewriter.
create<spirv::ConstantOp>(
154 if (
auto vectorType = dyn_cast<VectorType>(type)) {
155 assert(vectorType.getRank() == 1);
156 int count = vectorType.getNumElements();
161 rewriter.
create<spirv::CompositeConstructOp>(loc, intType, signSplat);
164 valueMask = rewriter.
create<spirv::CompositeConstructOp>(loc, intType,
169 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
171 rewriter.
create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
173 Value value = rewriter.
create<spirv::BitwiseAndOp>(
174 loc, intType,
ValueRange{lhsCast, valueMask});
178 Value result = rewriter.
create<spirv::BitwiseOrOp>(loc, intType,
190 struct CountLeadingZerosPattern final
195 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
200 Type type = getTypeConverter()->convertType(countOp.getType());
205 unsigned bitwidth = 0;
206 if (isa<IntegerType>(type))
208 if (
auto vectorType = dyn_cast<VectorType>(type))
209 bitwidth = vectorType.getElementTypeBitWidth();
214 Value input = adaptor.getOperand();
219 Value msb = rewriter.
create<spirv::GLFindUMsbOp>(loc, input);
224 Value subMsb = rewriter.
create<spirv::ISubOp>(loc, val31, msb);
229 Value subInput = rewriter.
create<spirv::ISubOp>(loc, val32, input);
230 Value cmp = rewriter.
create<spirv::ULessThanEqualOp>(loc, input, val1);
241 template <
typename ExpOp>
246 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
248 assert(adaptor.getOperands().size() == 1);
254 Type type = this->getTypeConverter()->convertType(operation.getType());
258 Value exp = rewriter.
create<ExpOp>(loc, type, adaptor.getOperand());
259 auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
269 template <
typename LogOp>
274 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
276 assert(adaptor.getOperands().size() == 1);
282 Type type = this->getTypeConverter()->convertType(operation.getType());
286 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
288 rewriter.
create<spirv::FAddOp>(loc, one, adaptor.getOperand());
299 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
304 Type dstType = getTypeConverter()->convertType(powfOp.getType());
310 if (
auto scalarType = dyn_cast<FloatType>(powfOp.getType())) {
311 scalarFloatType = scalarType;
312 }
else if (
auto vectorType = dyn_cast<VectorType>(powfOp.getType())) {
313 scalarFloatType = cast<FloatType>(vectorType.getElementType());
320 Type intType = scalarIntType;
321 if (
auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
322 auto shape = vectorType.getShape();
332 rewriter.
create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
333 Value abs = rewriter.
create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
341 rewriter.
create<spirv::ConvertFToSOp>(loc, intType, adaptor.getRhs());
342 Value intOne = spirv::ConstantOp::getOne(intType, loc, rewriter);
343 Value bitwiseAndOne =
344 rewriter.
create<spirv::BitwiseAndOp>(loc, intRhs, intOne);
345 Value isOdd = rewriter.
create<spirv::IEqualOp>(loc, bitwiseAndOne, intOne);
348 Value pow = rewriter.
create<spirv::GLPowOp>(loc,
abs, adaptor.getRhs());
349 Value negate = rewriter.
create<spirv::FNegateOp>(loc, pow);
352 rewriter.
create<spirv::LogicalAndOp>(loc, lessThan, isOdd);
364 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
370 Value operand = roundOp.getOperand();
375 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
377 if (VectorType vty = dyn_cast<VectorType>(ty)) {
378 half = rewriter.
create<spirv::ConstantOp>(
383 half = rewriter.
create<spirv::ConstantOp>(
387 auto abs = rewriter.
create<spirv::GLFAbsOp>(loc, operand);
391 rewriter.
create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
392 auto select = rewriter.
create<spirv::SelectOp>(loc, greater, one, zero);
393 auto add = rewriter.
create<spirv::FAddOp>(loc,
floor, select);
409 patterns.
add<CopySignPattern>(typeConverter, patterns.
getContext());
413 .
add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
414 ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
415 CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
416 CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
417 CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
418 CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
419 CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
420 CheckedElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
421 CheckedElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
422 CheckedElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
423 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::GLRoundEvenOp>,
424 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
425 CheckedElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
426 CheckedElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
427 CheckedElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
431 patterns.
add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
432 CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
433 CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
434 CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
435 CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
436 CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
437 CheckedElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
438 CheckedElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
439 CheckedElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
440 CheckedElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
441 CheckedElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
442 CheckedElementwiseOpPattern<math::RoundEvenOp, spirv::CLRintOp>,
443 CheckedElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
444 CheckedElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
445 CheckedElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
446 CheckedElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
447 CheckedElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static LogicalResult checkSourceOpTypes(ConversionPatternRewriter &rewriter, Operation *sourceOp)
Check if all sourceOp types are supported by math-to-spirv conversion.
static bool isSupportedSourceType(Type originalType)
Check if the type is supported by math-to-spirv conversion.
static Value getScalarOrVectorI32Constant(Type type, int value, OpBuilder &builder, Location loc)
Creates a 32-bit scalar/vector integer constant.
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
IntegerType getIntegerType(unsigned width)
DenseIntElementsAttr getI32VectorAttr(ArrayRef< int32_t > values)
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
unsigned getWidth()
Return the bitwidth of this float type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Type conversion from builtin types to SPIR-V types for shader interface.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
MPInt floor(const Fraction &f)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating Math ops to SPIR-V ops.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Converts elementwise unary, binary and ternary standard operations to SPIR-V operations.