18#include "llvm/Support/FormatVariadic.h"
29 return spirv::Version::V_1_0;
33 return spirv::Version::V_1_6;
37 if (isa<BFloat16Type>(
getType())) {
38 static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
45SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46 if (isa<BFloat16Type>(
getType())) {
47 static const auto capability = spirv::Capability::BFloat16DotProductKHR;
54std::optional<spirv::Version> DotOp::getMinVersion() {
58std::optional<spirv::Version> DotOp::getMaxVersion() {
66template <
typename IntegerDotProductOpTy>
69 "Not an integer dot product op?");
70 assert(op->
getNumResults() == 1 &&
"Expected a single result");
75 StringAttr packedVectorFormatAttrName =
76 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
77 if (
auto intTy = dyn_cast<IntegerType>(factorTy)) {
78 auto packedVectorFormat = dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
79 op->
getAttr(packedVectorFormatAttrName));
80 if (!packedVectorFormat)
81 return op->
emitOpError(
"requires Packed Vector Format attribute for "
82 "integer vector operands");
84 assert(packedVectorFormat.getValue() ==
85 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
86 "Unknown Packed Vector Format");
87 if (intTy.getWidth() != 32)
89 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
90 "integer vector operands to be 32-bits wide",
91 packedVectorFormat.getValue()));
93 if (op->
hasAttr(packedVectorFormatAttrName))
95 "with invalid format attribute for vector operands of type '{0}'",
102 if (factorBitWidth > resultBitWidth)
104 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
105 "for the specified vector operand type ({1} bits)",
106 resultBitWidth, factorBitWidth));
115 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
119template <
typename IntegerDotProductOpTy>
124 static const auto dotProductCap = spirv::Capability::DotProduct;
125 static const auto dotProductInput4x8BitPackedCap =
126 spirv::Capability::DotProductInput4x8BitPacked;
127 static const auto dotProductInput4x8BitCap =
128 spirv::Capability::DotProductInput4x8Bit;
129 static const auto dotProductInputAllCap =
130 spirv::Capability::DotProductInputAll;
135 StringAttr packedVectorFormatAttrName =
136 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
137 if (
auto intTy = dyn_cast<IntegerType>(factorTy)) {
138 auto formatAttr = cast<spirv::PackedVectorFormatAttr>(
139 op->
getAttr(packedVectorFormatAttrName));
140 if (formatAttr.getValue() ==
141 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
142 capabilities.push_back(dotProductInput4x8BitPackedCap);
147 auto vecTy = cast<VectorType>(factorTy);
148 if (vecTy.getElementTypeBitWidth() == 8) {
149 capabilities.push_back(dotProductInput4x8BitCap);
153 capabilities.push_back(dotProductInputAllCap);
157#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
158 LogicalResult OpName::verify() { \
159 return verifyIntegerDotProduct<OpName>(*this); \
161 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
162 return getIntegerDotProductExtensions(); \
164 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
165 return getIntegerDotProductCapabilities<OpName>(*this); \
167 std::optional<spirv::Version> OpName::getMinVersion() { \
168 return getDotProductMinVersion(); \
170 std::optional<spirv::Version> OpName::getMaxVersion() { \
171 return getDotProductMaxVersion(); \
181#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type front()
Return first type in the range.
Type getType() const
Return the type of this value.
static LogicalResult verifyIntegerDotProduct(Operation *op)
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
static std::optional< spirv::Version > getDotProductMinVersion()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.