MLIR 22.0.0git
DotProductOps.cpp
Go to the documentation of this file.
1//===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the Dot Product operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/Support/FormatVariadic.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24//===----------------------------------------------------------------------===//
25// Dot Product ops
26//===----------------------------------------------------------------------===//
27
28static std::optional<spirv::Version> getDotProductMinVersion() {
29 return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30}
31
32static std::optional<spirv::Version> getDotProductMaxVersion() {
33 return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34}
35
36SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
37 if (isa<BFloat16Type>(getType())) {
38 static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39 return {extension};
40 }
41
42 return {};
43}
44
45SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46 if (isa<BFloat16Type>(getType())) {
47 static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48 return {capability};
49 }
50
51 return {};
52}
53
54std::optional<spirv::Version> DotOp::getMinVersion() {
56}
57
58std::optional<spirv::Version> DotOp::getMaxVersion() {
60}
61
62//===----------------------------------------------------------------------===//
63// Integer Dot Product ops
64//===----------------------------------------------------------------------===//
65
66template <typename IntegerDotProductOpTy>
67static LogicalResult verifyIntegerDotProduct(Operation *op) {
68 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
69 "Not an integer dot product op?");
70 assert(op->getNumResults() == 1 && "Expected a single result");
71
72 // ODS enforces that vector 1 and vector 2, and result and the accumulator
73 // have the same types.
74 Type factorTy = op->getOperand(0).getType();
75 StringAttr packedVectorFormatAttrName =
76 IntegerDotProductOpTy::getFormatAttrName(op->getName());
77 if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
78 auto packedVectorFormat = dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
79 op->getAttr(packedVectorFormatAttrName));
80 if (!packedVectorFormat)
81 return op->emitOpError("requires Packed Vector Format attribute for "
82 "integer vector operands");
83
84 assert(packedVectorFormat.getValue() ==
85 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
86 "Unknown Packed Vector Format");
87 if (intTy.getWidth() != 32)
88 return op->emitOpError(
89 llvm::formatv("with specified Packed Vector Format ({0}) requires "
90 "integer vector operands to be 32-bits wide",
91 packedVectorFormat.getValue()));
92 } else {
93 if (op->hasAttr(packedVectorFormatAttrName))
94 return op->emitOpError(llvm::formatv(
95 "with invalid format attribute for vector operands of type '{0}'",
96 factorTy));
97 }
98
99 Type resultTy = op->getResultTypes().front();
100 unsigned factorBitWidth = getBitWidth(factorTy);
101 unsigned resultBitWidth = getBitWidth(resultTy);
102 if (factorBitWidth > resultBitWidth)
103 return op->emitOpError(
104 llvm::formatv("result type has insufficient bit-width ({0} bits) "
105 "for the specified vector operand type ({1} bits)",
106 resultBitWidth, factorBitWidth));
107
108 return success();
109}
110
113 // Requires the SPV_KHR_integer_dot_product extension, specified either
114 // explicitly or implied by target env's SPIR-V version >= 1.6.
115 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
116 return {extension};
117}
118
119template <typename IntegerDotProductOpTy>
122 // Requires the the DotProduct capability and capabilities that depend on
123 // exact op types.
124 static const auto dotProductCap = spirv::Capability::DotProduct;
125 static const auto dotProductInput4x8BitPackedCap =
126 spirv::Capability::DotProductInput4x8BitPacked;
127 static const auto dotProductInput4x8BitCap =
128 spirv::Capability::DotProductInput4x8Bit;
129 static const auto dotProductInputAllCap =
130 spirv::Capability::DotProductInputAll;
131
132 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
133
134 Type factorTy = op->getOperand(0).getType();
135 StringAttr packedVectorFormatAttrName =
136 IntegerDotProductOpTy::getFormatAttrName(op->getName());
137 if (auto intTy = dyn_cast<IntegerType>(factorTy)) {
138 auto formatAttr = cast<spirv::PackedVectorFormatAttr>(
139 op->getAttr(packedVectorFormatAttrName));
140 if (formatAttr.getValue() ==
141 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
142 capabilities.push_back(dotProductInput4x8BitPackedCap);
143
144 return capabilities;
145 }
146
147 auto vecTy = cast<VectorType>(factorTy);
148 if (vecTy.getElementTypeBitWidth() == 8) {
149 capabilities.push_back(dotProductInput4x8BitCap);
150 return capabilities;
151 }
152
153 capabilities.push_back(dotProductInputAllCap);
154 return capabilities;
155}
156
// Defines the verifier and the version/extension/capability queries of the
// integer dot product op class `OpName`. Each generated member forwards to the
// corresponding file-local helper defined above.
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  LogicalResult OpName::verify() {                                             \
    return verifyIntegerDotProduct<OpName>(*this);                             \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getDotProductMinVersion();                                          \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getDotProductMaxVersion();                                          \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities<OpName>(*this);                    \
  }
173
180
181#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
182
183} // namespace mlir::spirv
return success()
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition Operation.h:534
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition Operation.h:560
unsigned getNumOperands()
Definition Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type front()
Return first type in the range.
Definition TypeRange.h:152
Type getType() const
Return the type of this value.
Definition Value.h:105
static LogicalResult verifyIntegerDotProduct(Operation *op)
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
static std::optional< spirv::Version > getDotProductMinVersion()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304