MLIR  19.0.0git
IntegerDotProductOps.cpp
Go to the documentation of this file.
1 //===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the Integer Dot Product operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 //===----------------------------------------------------------------------===//
25 // Integer Dot Product ops
26 //===----------------------------------------------------------------------===//
27 
28 template <typename IntegerDotProductOpTy>
30  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
31  "Not an integer dot product op?");
32  assert(op->getNumResults() == 1 && "Expected a single result");
33 
34  // ODS enforces that vector 1 and vector 2, and result and the accumulator
35  // have the same types.
36  Type factorTy = op->getOperand(0).getType();
37  StringAttr packedVectorFormatAttrName =
38  IntegerDotProductOpTy::getFormatAttrName(op->getName());
39  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
40  auto packedVectorFormat =
41  llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
42  op->getAttr(packedVectorFormatAttrName));
43  if (!packedVectorFormat)
44  return op->emitOpError("requires Packed Vector Format attribute for "
45  "integer vector operands");
46 
47  assert(packedVectorFormat.getValue() ==
48  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
49  "Unknown Packed Vector Format");
50  if (intTy.getWidth() != 32)
51  return op->emitOpError(
52  llvm::formatv("with specified Packed Vector Format ({0}) requires "
53  "integer vector operands to be 32-bits wide",
54  packedVectorFormat.getValue()));
55  } else {
56  if (op->hasAttr(packedVectorFormatAttrName))
57  return op->emitOpError(llvm::formatv(
58  "with invalid format attribute for vector operands of type '{0}'",
59  factorTy));
60  }
61 
62  Type resultTy = op->getResultTypes().front();
63  unsigned factorBitWidth = getBitWidth(factorTy);
64  unsigned resultBitWidth = getBitWidth(resultTy);
65  if (factorBitWidth > resultBitWidth)
66  return op->emitOpError(
67  llvm::formatv("result type has insufficient bit-width ({0} bits) "
68  "for the specified vector operand type ({1} bits)",
69  resultBitWidth, factorBitWidth));
70 
71  return success();
72 }
73 
74 static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
75  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
76 }
77 
78 static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
79  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
80 }
81 
84  // Requires the SPV_KHR_integer_dot_product extension, specified either
85  // explicitly or implied by target env's SPIR-V version >= 1.6.
86  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
87  return {extension};
88 }
89 
90 template <typename IntegerDotProductOpTy>
93  // Requires the the DotProduct capability and capabilities that depend on
94  // exact op types.
95  static const auto dotProductCap = spirv::Capability::DotProduct;
96  static const auto dotProductInput4x8BitPackedCap =
97  spirv::Capability::DotProductInput4x8BitPacked;
98  static const auto dotProductInput4x8BitCap =
99  spirv::Capability::DotProductInput4x8Bit;
100  static const auto dotProductInputAllCap =
101  spirv::Capability::DotProductInputAll;
102 
103  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
104 
105  Type factorTy = op->getOperand(0).getType();
106  StringAttr packedVectorFormatAttrName =
107  IntegerDotProductOpTy::getFormatAttrName(op->getName());
108  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
109  auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
110  op->getAttr(packedVectorFormatAttrName));
111  if (formatAttr.getValue() ==
112  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
113  capabilities.push_back(dotProductInput4x8BitPackedCap);
114 
115  return capabilities;
116  }
117 
118  auto vecTy = llvm::cast<VectorType>(factorTy);
119  if (vecTy.getElementTypeBitWidth() == 8) {
120  capabilities.push_back(dotProductInput4x8BitCap);
121  return capabilities;
122  }
123 
124  capabilities.push_back(dotProductInputAllCap);
125  return capabilities;
126 }
127 
// Defines, for the integer dot product op class `OpName`, the verifier and the
// version / extension / capability query methods. Every generated method
// simply forwards to the matching static helper defined earlier in this file.
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  LogicalResult OpName::verify() {                                             \
    return verifyIntegerDotProduct<OpName>(*this);                             \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getIntegerDotProductMinVersion();                                   \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getIntegerDotProductMaxVersion();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities<OpName>(*this);                    \
  }
144 
151 
152 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
153 
154 } // namespace mlir::spirv
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
static LogicalResult verifyIntegerDotProduct(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Definition: SPIRVOpUtils.h:14
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26