18 #include "llvm/Support/FormatVariadic.h"
28 template <
typename IntegerDotProductOpTy>
31 "Not an integer dot product op?");
32 assert(op->
getNumResults() == 1 &&
"Expected a single result");
37 StringAttr packedVectorFormatAttrName =
38 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
39 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
40 auto packedVectorFormat =
41 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
42 op->
getAttr(packedVectorFormatAttrName));
43 if (!packedVectorFormat)
44 return op->
emitOpError(
"requires Packed Vector Format attribute for "
45 "integer vector operands");
47 assert(packedVectorFormat.getValue() ==
48 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
49 "Unknown Packed Vector Format");
50 if (intTy.getWidth() != 32)
52 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
53 "integer vector operands to be 32-bits wide",
54 packedVectorFormat.getValue()));
56 if (op->
hasAttr(packedVectorFormatAttrName))
58 "with invalid format attribute for vector operands of type '{0}'",
65 if (factorBitWidth > resultBitWidth)
67 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
68 "for the specified vector operand type ({1} bits)",
69 resultBitWidth, factorBitWidth));
75 return spirv::Version::V_1_0;
79 return spirv::Version::V_1_6;
86 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
90 template <
typename IntegerDotProductOpTy>
95 static const auto dotProductCap = spirv::Capability::DotProduct;
96 static const auto dotProductInput4x8BitPackedCap =
97 spirv::Capability::DotProductInput4x8BitPacked;
98 static const auto dotProductInput4x8BitCap =
99 spirv::Capability::DotProductInput4x8Bit;
100 static const auto dotProductInputAllCap =
101 spirv::Capability::DotProductInputAll;
106 StringAttr packedVectorFormatAttrName =
107 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
108 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
109 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
110 op->
getAttr(packedVectorFormatAttrName));
111 if (formatAttr.getValue() ==
112 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
113 capabilities.push_back(dotProductInput4x8BitPackedCap);
118 auto vecTy = llvm::cast<VectorType>(factorTy);
119 if (vecTy.getElementTypeBitWidth() == 8) {
120 capabilities.push_back(dotProductInput4x8BitCap);
124 capabilities.push_back(dotProductInputAllCap);
128 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
129 LogicalResult OpName::verify() { \
130 return verifyIntegerDotProduct<OpName>(*this); \
132 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
133 return getIntegerDotProductExtensions(); \
135 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
136 return getIntegerDotProductCapabilities<OpName>(*this); \
138 std::optional<spirv::Version> OpName::getMinVersion() { \
139 return getIntegerDotProductMinVersion(); \
141 std::optional<spirv::Version> OpName::getMaxVersion() { \
142 return getIntegerDotProductMaxVersion(); \
152 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type front()
Return first type in the range.
Type getType() const
Return the type of this value.
static LogicalResult verifyIntegerDotProduct(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.