MLIR  18.0.0git
IntegerDotProductOps.cpp
Go to the documentation of this file.
1 //===- IntegerDotProductOps.cpp - MLIR SPIR-V Integer Dot Product Ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the Integer Dot Product operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/Support/FormatVariadic.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 //===----------------------------------------------------------------------===//
25 // Integer Dot Product ops
26 //===----------------------------------------------------------------------===//
27 
29  assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&
30  "Not an integer dot product op?");
31  assert(op->getNumResults() == 1 && "Expected a single result");
32 
33  // ODS enforces that vector 1 and vector 2, and result and the accumulator
34  // have the same types.
35  Type factorTy = op->getOperand(0).getType();
36  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
37  auto packedVectorFormat =
38  llvm::dyn_cast_or_null<spirv::PackedVectorFormatAttr>(
39  op->getAttr(kPackedVectorFormatAttrName));
40  if (!packedVectorFormat)
41  return op->emitOpError("requires Packed Vector Format attribute for "
42  "integer vector operands");
43 
44  assert(packedVectorFormat.getValue() ==
45  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&
46  "Unknown Packed Vector Format");
47  if (intTy.getWidth() != 32)
48  return op->emitOpError(
49  llvm::formatv("with specified Packed Vector Format ({0}) requires "
50  "integer vector operands to be 32-bits wide",
51  packedVectorFormat.getValue()));
52  } else {
53  if (op->hasAttr(kPackedVectorFormatAttrName))
54  return op->emitOpError(llvm::formatv(
55  "with invalid format attribute for vector operands of type '{0}'",
56  factorTy));
57  }
58 
59  Type resultTy = op->getResultTypes().front();
60  unsigned factorBitWidth = getBitWidth(factorTy);
61  unsigned resultBitWidth = getBitWidth(resultTy);
62  if (factorBitWidth > resultBitWidth)
63  return op->emitOpError(
64  llvm::formatv("result type has insufficient bit-width ({0} bits) "
65  "for the specified vector operand type ({1} bits)",
66  resultBitWidth, factorBitWidth));
67 
68  return success();
69 }
70 
71 static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
72  return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
73 }
74 
75 static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
76  return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
77 }
78 
81  // Requires the SPV_KHR_integer_dot_product extension, specified either
82  // explicitly or implied by target env's SPIR-V version >= 1.6.
83  static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
84  return {extension};
85 }
86 
89  // Requires the DotProduct capability and capabilities that depend on
90  // exact op types.
91  static const auto dotProductCap = spirv::Capability::DotProduct;
92  static const auto dotProductInput4x8BitPackedCap =
93  spirv::Capability::DotProductInput4x8BitPacked;
94  static const auto dotProductInput4x8BitCap =
95  spirv::Capability::DotProductInput4x8Bit;
96  static const auto dotProductInputAllCap =
97  spirv::Capability::DotProductInputAll;
98 
99  SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
100 
101  Type factorTy = op->getOperand(0).getType();
102  if (auto intTy = llvm::dyn_cast<IntegerType>(factorTy)) {
103  auto formatAttr = llvm::cast<spirv::PackedVectorFormatAttr>(
104  op->getAttr(kPackedVectorFormatAttrName));
105  if (formatAttr.getValue() ==
106  spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
107  capabilities.push_back(dotProductInput4x8BitPackedCap);
108 
109  return capabilities;
110  }
111 
112  auto vecTy = llvm::cast<VectorType>(factorTy);
113  if (vecTy.getElementTypeBitWidth() == 8) {
114  capabilities.push_back(dotProductInput4x8BitCap);
115  return capabilities;
116  }
117 
118  capabilities.push_back(dotProductInputAllCap);
119  return capabilities;
120 }
121 
// Defines the verifier and the availability hooks (min/max SPIR-V version,
// required extensions, required capabilities) for the integer dot product op
// class `OpName`. Each generated member forwards to the shared static helper of
// the same purpose in this file, so every op stamped out with this macro gets
// identical checks and requirements. The macro is #undef'd after its uses.
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)                              \
  std::optional<spirv::Version> OpName::getMinVersion() {                      \
    return getIntegerDotProductMinVersion();                                   \
  }                                                                            \
  std::optional<spirv::Version> OpName::getMaxVersion() {                      \
    return getIntegerDotProductMaxVersion();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() {         \
    return getIntegerDotProductExtensions();                                   \
  }                                                                            \
  SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() {      \
    return getIntegerDotProductCapabilities(*this);                            \
  }                                                                            \
  LogicalResult OpName::verify() {                                             \
    return verifyIntegerDotProduct(*this);                                     \
  }
136 
143 
144 #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
145 
146 } // namespace mlir::spirv
#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName)
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
constexpr char kPackedVectorFormatAttrName[]
static std::optional< spirv::Version > getIntegerDotProductMinVersion()
static SmallVector< ArrayRef< spirv::Capability >, 1 > getIntegerDotProductCapabilities(Operation *op)
static std::optional< spirv::Version > getIntegerDotProductMaxVersion()
static LogicalResult verifyIntegerDotProduct(Operation *op)
static SmallVector< ArrayRef< spirv::Extension >, 1 > getIntegerDotProductExtensions()
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Definition: SPIRVOpUtils.h:14
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26