MLIR  22.0.0git
DotProductOps.cpp
Go to the documentation of this file.
1 //===- DotProductOps.cpp - MLIR SPIR-V Dot Product Ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the Dot Product operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 //===----------------------------------------------------------------------===//
25 // Dot Product ops
26 //===----------------------------------------------------------------------===//
27 
28 static std::optional<spirv::Version> getDotProductMinVersion() {
29  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
30 }
31 
32 static std::optional<spirv::Version> getDotProductMaxVersion() {
33  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
34 }
35 
36 SmallVector<ArrayRef<spirv::Extension>, 1> DotOp::getExtensions() {
37  if (isa<BFloat16Type>(getType())) {
38  static const auto extension = spirv::Extension::SPV_KHR_bfloat16;
39  return {extension};
40  }
41 
42  return {};
43 }
44 
45 SmallVector<ArrayRef<spirv::Capability>, 1> DotOp::getCapabilities() {
46  if (isa<BFloat16Type>(getType())) {
47  static const auto capability = spirv::Capability::BFloat16DotProductKHR;
48  return {capability};
49  }
50 
51  return {};
52 }
53 
54 std::optional<spirv::Version> DotOp::getMinVersion() {
55  return getDotProductMinVersion();
56 }
57 
58 std::optional<spirv::Version> DotOp::getMaxVersion() {
59  return getDotProductMaxVersion();
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // Integer Dot Product ops
64 //===----------------------------------------------------------------------===//
65 
66 template <typename IntegerDotProductOpTy>
67 static LogicalResult verifyIntegerDotProduct(Operation *op) {
68  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
69  "Not an integer dot product op?");
70  assert(op->getNumResults() == 1 && "Expected a single result");
71 
72  // ODS enforces that vector 1 and vector 2, and result and the accumulator
73  // have the same types.
74  Type factorTy = op->getOperand(0).getType();
75  StringAttr packedVectorFormatAttrName =
76  IntegerDotProductOpTy::getFormatAttrName(op->getName());
77  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
78  auto packedVectorFormat =
79  llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
80  op->getAttr(packedVectorFormatAttrName));
81  if (!packedVectorFormat)
82  return op->emitOpError("requires Packed Vector Format attribute for "
83  "integer vector operands");
84 
85  assert(packedVectorFormat.getValue() ==
86  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
87  "Unknown Packed Vector Format");
88  if (intTy.getWidth() != 32)
89  return op->emitOpError(
90  llvm::formatv("with specified Packed Vector Format ({0}) requires "
91  "integer vector operands to be 32-bits wide",
92  packedVectorFormat.getValue()));
93  } else {
94  if (op->hasAttr(packedVectorFormatAttrName))
95  return op->emitOpError(llvm::formatv(
96  "with invalid format attribute for vector operands of type '{0}'",
97  factorTy));
98  }
99 
100  Type resultTy = op->getResultTypes().front();
101  unsigned factorBitWidth = getBitWidth(factorTy);
102  unsigned resultBitWidth = getBitWidth(resultTy);
103  if (factorBitWidth > resultBitWidth)
104  return op->emitOpError(
105  llvm::formatv("result type has insufficient bit-width ({0} bits) "
106  "for the specified vector operand type ({1} bits)",
107  resultBitWidth, factorBitWidth));
108 
109  return success();
110 }
111 
114  // Requires the SPV_KHR_integer_dot_product extension, specified either
115  // explicitly or implied by target env's SPIR-V version >= 1.6.
116  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
117  return {extension};
118 }
119 
120 template <typename IntegerDotProductOpTy>
123  // Requires the the DotProduct capability and capabilities that depend on
124  // exact op types.
125  static const auto dotProductCap = spirv::Capability::DotProduct;
126  static const auto dotProductInput4x8BitPackedCap =
127  spirv::Capability::DotProductInput4x8BitPacked;
128  static const auto dotProductInput4x8BitCap =
129  spirv::Capability::DotProductInput4x8Bit;
130  static const auto dotProductInputAllCap =
131  spirv::Capability::DotProductInputAll;
132 
133  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
134 
135  Type factorTy = op->getOperand(0).getType();
136  StringAttr packedVectorFormatAttrName =
137  IntegerDotProductOpTy::getFormatAttrName(op->getName());
138  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
139  auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
140  op->getAttr(packedVectorFormatAttrName));
141  if (formatAttr.getValue() ==
142  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
143  capabilities.push_back(dotProductInput4x8BitPackedCap);
144 
145  return capabilities;
146  }
147 
148  auto vecTy = llvm::cast<VectorType>(factorTy);
149  if (vecTy.getElementTypeBitWidth() == 8) {
150  capabilities.push_back(dotProductInput4x8BitCap);
151  return capabilities;
152  }
153 
154  capabilities.push_back(dotProductInputAllCap);
155  return capabilities;
156 }
157 
// Defines, for one integer dot product op class `OpName`, its verifier and
// its availability hooks (min/max version, extensions, capabilities). Each
// member simply forwards to the shared static helper for that concern.
// Comments are kept outside the macro so none can swallow a line
// continuation.
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getDotProductMinVersion();                                          \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getDotProductMaxVersion();                                          \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities<OpName>(*this);                    \
  }                                                                            \
  LogicalResult OpName::verify() {                                             \
    return verifyIntegerDotProduct<OpName>(*this);                             \
  }
174 
181 
182 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
183 
184 } // namespace mlir::spirv
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Definition: Operation.h:560
unsigned getNumOperands()
Definition: Operation.h:346
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Type front()
Return first type in the range.
Definition: TypeRange.h:152
Type getType() const
Return the type of this value.
Definition: Value.h:105
static LogicalResult verifyIntegerDotProduct(Operation *op)
static std::optional< spirv::Version > getDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Definition: SPIRVOpUtils.h:14
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304