MLIR  21.0.0git
QuantOps.cpp
Go to the documentation of this file.
1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "QuantDialectBytecode.h"
10 #include "TypeDetail.h"
11 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
19 
20 namespace mlir {
21 namespace quant {
22 
23 namespace {
24 
25 // Verify the integrity of per-axis quantization information, if present.
26 //
27 // - uniformQuantizedPerAxisType
28 // A quantized type with per-axis quantization.
29 //
30 // - containerType
31 // Original input or result type of the operation using the provided quantized
32 // type. Used to ensure that the quantized type appears within a tensor and
33 // that the tensor is compatible with per-axis quantization information.
34 //
35 LogicalResult verifyPerAxisQuantization(
36  Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
37  Type containerType) {
38  auto tensorType = dyn_cast<TensorType>(containerType);
39  if (!tensorType)
40  return op->emitError("scalar types may not use per-axis quantization");
41 
42  if (!tensorType.hasRank())
43  return success();
44 
45  int32_t quantizedDimension =
46  uniformQuantizedPerAxisType.getQuantizedDimension();
47  if ((int64_t)quantizedDimension >= tensorType.getRank())
48  return op->emitError("quantized dimension must be less than tensor rank");
49 
50  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
51  if (quantizedDimensionSize != ShapedType::kDynamic &&
52  quantizedDimensionSize !=
53  (int64_t)uniformQuantizedPerAxisType.getScales().size())
54  return op->emitError(
55  "quantized dimension size does not match number of scales");
56 
57  return success();
58 }
59 
60 // Verifies that the sub-channel quantization parameters are consistent with
61 // the given container type. The function checks the following:
62 //
63 // - The container type must be a ranked tensor type.
64 // - Each quantized dimension must be less than the rank of the tensor.
65 // - The size of each dimension at the quantized dimension must be divisible
66 // by the corresponding block size.
67 // - The scale dimension size at each axis index should match the tensor
68 // dimension at the index divided by the corresponding block size.
69 //
70 // The `uniformQuantizedSubChannelType` argument provides the sub-channel
71 // quantization parameters, and the `containerType` argument specifies the
72 // type of the container holding the quantized data.
73 //
74 LogicalResult verifySubChannelQuantization(
75  Operation *op,
76  UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
77  Type containerType) {
78  auto tensorType = dyn_cast<TensorType>(containerType);
79  if (!tensorType)
80  return op->emitError("scalar types may not use sub-channel quantization");
81 
82  if (!tensorType.hasRank())
83  return op->emitError(
84  "tensor containing the sub-channel quantized type must be ranked");
85 
86  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
87  uniformQuantizedSubChannelType.getBlockSizeInfo();
88  auto shape = tensorType.getShape();
89 
90  // The dimension size of scale for an axis which is not specified as quantized
91  // dimension should be 1.
92  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
93  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
94  if (quantizedDimension >= tensorType.getRank())
95  return op->emitError()
96  << "quantized dimension " << quantizedDimension
97  << " must be less than tensor rank " << tensorType.getRank();
98  if (!tensorType.isDynamicDim(quantizedDimension) &&
99  tensorType.getDimSize(quantizedDimension) % blockSize != 0)
100  return op->emitError()
101  << "tensor dimension size "
102  << tensorType.getDimSize(quantizedDimension) << " at axis "
103  << quantizedDimension
104  << " must be divisible by the corresponding block size "
105  << blockSize;
106  if (tensorType.isDynamicDim(quantizedDimension))
107  expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
108  else
109  expectedScaleShape[quantizedDimension] =
110  tensorType.getDimSize(quantizedDimension) / blockSize;
111  }
112 
113  // Block sizes must be greater than 0 and divide the corresponding dimension
114  // size. While a block size b must be less than or equal to the corresponding
115  // dimension size d, this constraint is implicitly enforced by requiring that
116  // d % b == 0 when d != 0.
117  //
118  // However, a problem arises when d = 0. The divisibility constraint allows b
119  // to be any value, potentially violating the requirement that b <= d.
120  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
121  // constraint that b > 0.
122  //
123  // Therefore, we explicitly disallow the case where d = 0 to maintain
124  // consistency and avoid these issues.
125  if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) {
126  return op->emitError() << "tensor dimension size of zero is not allowed "
127  "with sub-channel quantization";
128  }
129 
130  auto scaleShape =
131  uniformQuantizedSubChannelType.getScales().getType().getShape();
132  if (scaleShape.size() != shape.size()) {
133  return op->emitError() << "Rank of scales " << scaleShape.size()
134  << " must match "
135  << "the rank of the tensor " << shape.size();
136  }
137 
138  for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
139  if (expectedScaleShape[index] != ShapedType::kDynamic &&
140  expectedScaleShape[index] != scaleShape[index])
141  return op->emitError() << "dimension size " << scaleDim
142  << " of scales tensor at axis " << index
143  << " should match (tensor dimension at axis / "
144  "block sizes at axis) = "
145  << expectedScaleShape[index];
146  }
147 
148  return success();
149 }
150 
151 // Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
152 //
153 // - quantizedType
154 // Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
155 // whether as a primitive type or in a tensor.
156 //
157 // - floatType
158 // Float type used in the input ('quant.qcast') or result ('quant.dcast'),
159 // whether as a primitive type or in a tensor.
160 //
161 // - containerType
162 // Type of original input or result.
163 //
164 LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
165  FloatType floatType, Type containerType) {
166  if (quantizedType.getExpressedType() != floatType)
167  return op->emitError(
168  "expressed type in quantized type expected to match float type");
169 
170  // Verify integrity of per-axis quantization information, if present.
171  if (auto quantizedPerAxisType =
172  dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
173  return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
174  }
175 
176  if (auto quantizedSubChannelType =
177  dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
178  return verifySubChannelQuantization(op, quantizedSubChannelType,
179  containerType);
180  }
181 
182  // At this point the type is UniformQuantizedType
183  return success();
184 }
185 
186 } // namespace
187 
188 //===----------------------------------------------------------------------===//
189 // Dialect
190 //===----------------------------------------------------------------------===//
191 
192 void QuantDialect::initialize() {
193  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
194  UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
195  addOperations<
196 #define GET_OP_LIST
197 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
198  >();
200 }
201 
202 //===----------------------------------------------------------------------===//
203 // DequantizeCastOp
204 //===----------------------------------------------------------------------===//
205 
206 LogicalResult DequantizeCastOp::verify() {
207  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
208  getInput().getType());
209 }
210 
211 OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
212  // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
213  // with the value of x. Values x and y are guaranteed to be of the same type
214  // in this pattern.
215  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
216  if (!srcQcastOp)
217  return {};
218  assert(srcQcastOp.getInput().getType() == getType());
219  return srcQcastOp.getInput();
220 }
221 
222 FloatType DequantizeCastOp::getFloatType() {
223  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
224 }
225 
226 QuantizedType DequantizeCastOp::getQuantizedType() {
227  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // QuantizeCastOp
232 //===----------------------------------------------------------------------===//
233 
234 LogicalResult QuantizeCastOp::verify() {
235  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
236  getInput().getType());
237 }
238 
239 OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
240  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
241  // with the value of x if the casts invert each other. Contrary to the folding
242  // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
243  // x and y are not guaranteed to be of the same type here, as they may use
244  // different quantization parameters.
245  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
246  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
247  return {};
248  return srcDcastOp.getInput();
249 }
250 
251 FloatType QuantizeCastOp::getFloatType() {
252  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
253 }
254 
255 QuantizedType QuantizeCastOp::getQuantizedType() {
256  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // StorageCastOp
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult StorageCastOp::verify() {
264  auto quantizedType = getQuantizedType();
265  auto integerType = getIntegerType();
266  if (quantizedType.getStorageType() != integerType)
267  return emitError(
268  "storage type in quantized type expected to match integer type");
269 
270  // Verify integrity of per-axis quantization information, if available. While
271  // the quantization type may appear in the input or the result, their tensor
272  // shapes are guaranteed to be identical at this point.
273  if (auto quantizedPerAxisType =
274  dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
275  return verifyPerAxisQuantization(*this, quantizedPerAxisType,
276  getInput().getType());
277  }
278 
279  if (auto quantizedSunChannelType =
280  dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
281  return verifySubChannelQuantization(*this, quantizedSunChannelType,
282  getInput().getType());
283  }
284 
285  // At this point the type is UniformQuantizedType
286  return success();
287 }
288 
289 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
290  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
291  // quant.scast with the value of x if the casts invert each other.
292  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
293  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
294  return {};
295  return srcScastOp.getInput();
296 }
297 
298 IntegerType StorageCastOp::getIntegerType() {
299  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
300  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
301  return integerType;
302 
303  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
304  return cast<IntegerType>(resultScalarType);
305 }
306 
307 QuantizedType StorageCastOp::getQuantizedType() {
308  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
309  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
310  return quantizedType;
311 
312  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
313  return cast<QuantizedType>(resultScalarType);
314 }
315 
316 } // namespace quant
317 } // namespace mlir
318 
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not supported.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:424