18 #include "llvm/Support/FormatVariadic.h"
29 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
30 "Not an integer dot product op?");
31 assert(op->getNumResults() == 1 &&
"Expected a single result");
35 Type factorTy = op->getOperand(0).getType();
36 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
37 auto packedVectorFormat =
38 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
40 if (!packedVectorFormat)
41 return op->emitOpError(
"requires Packed Vector Format attribute for "
42 "integer vector operands");
44 assert(packedVectorFormat.getValue() ==
45 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
46 "Unknown Packed Vector Format");
47 if (intTy.getWidth() != 32)
48 return op->emitOpError(
49 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
50 "integer vector operands to be 32-bits wide",
51 packedVectorFormat.getValue()));
54 return op->emitOpError(llvm::formatv(
55 "with invalid format attribute for vector operands of type '{0}'",
59 Type resultTy = op->getResultTypes().front();
62 if (factorBitWidth > resultBitWidth)
63 return op->emitOpError(
64 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
65 "for the specified vector operand type ({1} bits)",
66 resultBitWidth, factorBitWidth));
72 return spirv::Version::V_1_0;
76 return spirv::Version::V_1_6;
83 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
91 static const auto dotProductCap = spirv::Capability::DotProduct;
92 static const auto dotProductInput4x8BitPackedCap =
93 spirv::Capability::DotProductInput4x8BitPacked;
94 static const auto dotProductInput4x8BitCap =
95 spirv::Capability::DotProductInput4x8Bit;
96 static const auto dotProductInputAllCap =
97 spirv::Capability::DotProductInputAll;
101 Type factorTy = op->getOperand(0).getType();
102 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
103 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
105 if (formatAttr.getValue() ==
106 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
107 capabilities.push_back(dotProductInput4x8BitPackedCap);
112 auto vecTy = llvm::cast<VectorType>(factorTy);
113 if (vecTy.getElementTypeBitWidth() == 8) {
114 capabilities.push_back(dotProductInput4x8BitCap);
118 capabilities.push_back(dotProductInputAllCap);
122 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
123 LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
124 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
125 return getIntegerDotProductExtensions(); \
127 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
128 return getIntegerDotProductCapabilities(*this); \
130 std::optional<spirv::Version> OpName::getMinVersion() { \
131 return getIntegerDotProductMinVersion(); \
133 std::optional<spirv::Version> OpName::getMaxVersion() { \
134 return getIntegerDotProductMaxVersion(); \
144 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
constexpr char kPackedVectorFormatAttrName[]
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static LogicalResult verifyIntegerDotProduct(Operation *op)
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.