18 #include "llvm/Support/FormatVariadic.h"
// Verifies an integer dot product op whose concrete class is
// IntegerDotProductOpTy. The class supplies the op-specific name of the packed
// vector format attribute via getFormatAttrName().
//
// Checks visible in this excerpt:
//  - the op has 2 or 3 operands and exactly one result (debug-only asserts);
//  - if operand 0 is a scalar IntegerType, the op must carry a
//    PackedVectorFormatAttr, and the integer must be exactly 32 bits wide;
//  - if the op carries the format attribute in the other case, an error is
//    emitted;
//  - the result bit-width must not be smaller than the operand bit-width.
//
// NOTE(review): this excerpt is incomplete. The function signature, the
// `else` between the two format checks, the factorBitWidth/resultBitWidth
// computations and the final return are not shown. Confirm against the full
// file.
28 template <
typename IntegerDotProductOpTy>
30 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
31 "Not an integer dot product op?");
32 assert(op->getNumResults() == 1 &&
"Expected a single result");
// Only operand 0's type is inspected. Presumably the other factor is
// constrained to the same type elsewhere (e.g. by ODS) — confirm.
36 Type factorTy = op->getOperand(0).getType();
37 StringAttr packedVectorFormatAttrName =
38 IntegerDotProductOpTy::getFormatAttrName(op->getName());
39 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
// A scalar integer operand requires an explicit packed vector format attribute.
40 auto packedVectorFormat =
41 llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
42 op->getAttr(packedVectorFormatAttrName));
43 if (!packedVectorFormat)
44 return op->emitOpError(
"requires Packed Vector Format attribute for "
45 "integer vector operands");
// Only the 4x8-bit packed format is expected. This is enforced by assert
// only, so in release builds any other format value falls through to the
// 32-bit width check below.
47 assert(packedVectorFormat.getValue() ==
48 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
49 "Unknown Packed Vector Format");
50 if (intTy.getWidth() != 32)
51 return op->emitOpError(
52 llvm::formatv(
"with specified Packed Vector Format ({0}) requires "
53 "integer vector operands to be 32-bits wide",
54 packedVectorFormat.getValue()));
// NOTE(review): the line that separates the two branches is not visible here.
// Presumably this is the non-IntegerType branch, which rejects a stray format
// attribute — confirm.
56 if (op->hasAttr(packedVectorFormatAttrName))
57 return op->emitOpError(llvm::formatv(
58 "with invalid format attribute for vector operands of type '{0}'",
62 Type resultTy = op->getResultTypes().front();
// factorBitWidth/resultBitWidth are computed in lines not shown here
// (presumably via getBitWidth on factorTy/resultTy — confirm). The result
// must be at least as wide as the operand.
65 if (factorBitWidth > resultBitWidth)
66 return op->emitOpError(
67 llvm::formatv(
"result type has insufficient bit-width ({0} bits) "
68 "for the specified vector operand type ({1} bits)",
69 resultBitWidth, factorBitWidth));
// NOTE(review): isolated return from a function whose header is not visible.
// Presumably this is getIntegerDotProductMinVersion (minimum SPIR-V version
// V_1_0) — confirm.
75 return spirv::Version::V_1_0;
// NOTE(review): isolated return from a function whose header is not visible.
// Presumably this is getIntegerDotProductMaxVersion (maximum SPIR-V version
// V_1_6) — confirm.
79 return spirv::Version::V_1_6;
// Presumably part of getIntegerDotProductExtensions, which returns
// SmallVector<ArrayRef<spirv::Extension>, 1>. The value is `static` so that
// an ArrayRef referring to it stays valid after the function returns —
// confirm against the full body.
86 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
// Collects the SPIR-V capabilities required by an integer dot product op whose
// concrete class is IntegerDotProductOpTy.
//
// Capabilities pushed in the visible lines:
//  - scalar IntegerType operand 0 with the 4x8-bit packed format
//    -> DotProductInput4x8BitPacked;
//  - otherwise operand 0 is cast to VectorType. With 8-bit elements
//    -> DotProductInput4x8Bit;
//  - DotProductInputAll is pushed last (presumably only when neither branch
//    above has returned).
// dotProductCap (DotProduct) is declared here but not used in the visible
// lines; presumably it seeds `capabilities`.
//
// NOTE(review): the signature, the declaration/initialization of
// `capabilities` and the return statements are not visible in this excerpt.
90 template <
typename IntegerDotProductOpTy>
// The capability values have static storage because the result is a
// SmallVector<ArrayRef<spirv::Capability>, 1>. An ArrayRef built from a
// function-local value would dangle once the function returns.
95 static const auto dotProductCap = spirv::Capability::DotProduct;
96 static const auto dotProductInput4x8BitPackedCap =
97 spirv::Capability::DotProductInput4x8BitPacked;
98 static const auto dotProductInput4x8BitCap =
99 spirv::Capability::DotProductInput4x8Bit;
100 static const auto dotProductInputAllCap =
101 spirv::Capability::DotProductInputAll;
105 Type factorTy = op->getOperand(0).getType();
106 StringAttr packedVectorFormatAttrName =
107 IntegerDotProductOpTy::getFormatAttrName(op->getName());
108 if (
auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
// Uses llvm::cast (not dyn_cast_or_null), so it assumes the op already passed
// verification, which requires this attribute for integer operands. On an
// unverified op that lacks the attribute, this would assert in debug builds.
// NOTE(review): `intTy` is not used in the visible lines. If it is unused in
// the full body, llvm::isa<IntegerType> would express the test more directly.
109 auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
110 op->getAttr(packedVectorFormatAttrName));
111 if (formatAttr.getValue() ==
112 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
113 capabilities.push_back(dotProductInput4x8BitPackedCap);
// A non-integer operand must be a VectorType; llvm::cast asserts this.
118 auto vecTy = llvm::cast<VectorType>(factorTy);
119 if (vecTy.getElementTypeBitWidth() == 8) {
120 capabilities.push_back(dotProductInput4x8BitCap);
124 capabilities.push_back(dotProductInputAllCap);
// Defines the verify(), getExtensions(), getCapabilities(), getMinVersion()
// and getMaxVersion() members of the integer dot product op class `OpName` by
// forwarding to the shared helpers. verify() and getCapabilities()
// instantiate the helper templates with OpName, so each op's own
// format-attribute name (OpName::getFormatAttrName) is used.
// NOTE(review): the macro's closing lines and its instantiations (the list of
// ops) are not visible in this excerpt. The macro is #undef'd after use.
128 #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
129 LogicalResult OpName::verify() { \
130 return verifyIntegerDotProduct<OpName>(*this); \
132 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
133 return getIntegerDotProductExtensions(); \
135 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
136 return getIntegerDotProductCapabilities<OpName>(*this); \
138 std::optional<spirv::Version> OpName::getMinVersion() { \
139 return getIntegerDotProductMinVersion(); \
141 std::optional<spirv::Version> OpName::getMaxVersion() { \
142 return getIntegerDotProductMaxVersion(); \
152 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static LogicalResult verifyIntegerDotProduct(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.