18 #include "llvm/Support/FormatVariadic.h"
29 return spirv::Version::V_1_0;
33 return spirv::Version::V_1_6;
37 if (isa<BFloat16Type>(
getType())) {
38 static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
45 SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46 if (isa<BFloat16Type>(
getType())) {
47 static const auto capability = spirv::Capability::BFloat16DotProductKHR;
54 std::optional<spirv::Version> DotOp::getMinVersion() {
58 std::optional<spirv::Version> DotOp::getMaxVersion() {
66 template <
typename IntegerDotProductOpTy>
69 "Not an integer dot product op?");
70 assert(op->
getNumResults() == 1 &&
"Expected a single result");
75 StringAttr packedVectorFormatAttrName =
76 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
77 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
78 auto packedVectorFormat =
79 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
80 op->
getAttr(packedVectorFormatAttrName));
81 if (!packedVectorFormat)
82 return op->
emitOpError(
"requires Packed Vector Format attribute for "
83 "integer vector operands");
85 assert(packedVectorFormat.getValue() ==
86 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
87 "Unknown Packed Vector Format");
88 if (intTy.getWidth() != 32)
90 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
91 "integer vector operands to be 32-bits wide",
92 packedVectorFormat.getValue()));
94 if (op->
hasAttr(packedVectorFormatAttrName))
96 "with invalid format attribute for vector operands of type '{0}'",
103 if (factorBitWidth > resultBitWidth)
105 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
106 "for the specified vector operand type ({1} bits)",
107 resultBitWidth, factorBitWidth));
116 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
120 template <
typename IntegerDotProductOpTy>
125 static const auto dotProductCap = spirv::Capability::DotProduct;
126 static const auto dotProductInput4x8BitPackedCap =
127 spirv::Capability::DotProductInput4x8BitPacked;
128 static const auto dotProductInput4x8BitCap =
129 spirv::Capability::DotProductInput4x8Bit;
130 static const auto dotProductInputAllCap =
131 spirv::Capability::DotProductInputAll;
136 StringAttr packedVectorFormatAttrName =
137 IntegerDotProductOpTy::getFormatAttrName(op->
getName());
138 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
139 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
140 op->
getAttr(packedVectorFormatAttrName));
141 if (formatAttr.getValue() ==
142 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
143 capabilities.push_back(dotProductInput4x8BitPackedCap);
148 auto vecTy = llvm::cast<VectorType>(factorTy);
149 if (vecTy.getElementTypeBitWidth() == 8) {
150 capabilities.push_back(dotProductInput4x8BitCap);
154 capabilities.push_back(dotProductInputAllCap);
158 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
159 LogicalResult OpName::verify() { \
160 return verifyIntegerDotProduct<OpName>(*this); \
162 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
163 return getIntegerDotProductExtensions(); \
165 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
166 return getIntegerDotProductCapabilities<OpName>(*this); \
168 std::optional<spirv::Version> OpName::getMinVersion() { \
169 return getDotProductMinVersion(); \
171 std::optional<spirv::Version> OpName::getMaxVersion() { \
172 return getDotProductMaxVersion(); \
182 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type front()
Return first type in the range.
Type getType() const
Return the type of this value.
static LogicalResult verifyIntegerDotProduct(Operation *op)
static std::optional< spirv::Version > getDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.