MLIR 22.0.0git
QuantOps.cpp
Go to the documentation of this file.
1//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11
17
18#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
19
20namespace mlir {
21namespace quant {
22
23namespace {
24
25// Verify the integrity of per-axis quantization information, if present.
26//
27// - uniformQuantizedPerAxisType
28// A quantized type with per-axis quantization.
29//
30// - containerType
31// Original input or result type of the operation using the provided quantized
32// type. Used to ensure that the quantized type appears within a tensor and
33// that the tensor is compatible with per-axis quantization information.
34//
35LogicalResult verifyPerAxisQuantization(
36 Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
37 Type containerType) {
38 auto tensorType = dyn_cast<TensorType>(containerType);
39 if (!tensorType)
40 return op->emitError("scalar types may not use per-axis quantization");
41
42 if (!tensorType.hasRank())
43 return success();
44
45 int32_t quantizedDimension =
46 uniformQuantizedPerAxisType.getQuantizedDimension();
47 if ((int64_t)quantizedDimension >= tensorType.getRank())
48 return op->emitError("quantized dimension must be less than tensor rank");
49
50 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
51 if (quantizedDimensionSize != ShapedType::kDynamic &&
52 quantizedDimensionSize !=
53 (int64_t)uniformQuantizedPerAxisType.getScales().size())
54 return op->emitError(
55 "quantized dimension size does not match number of scales");
56
57 return success();
58}
59
// Verifies that the sub-channel quantization parameters are consistent with
// the given container type. The function checks the following:
//
// - The container type must be a ranked tensor type.
// - Each quantized dimension must be less than the rank of the tensor.
// - The size of each dimension at the quantized dimension must be divisible
//   by the corresponding block size.
// - The scale dimension size at each axis index should match the tensor
//   dimension at the index divided by the corresponding block size.
//
// The `uniformQuantizedSubChannelType` argument provides the sub-channel
// quantization parameters, and the `containerType` argument specifies the
// type of the container holding the quantized data.
//
LogicalResult verifySubChannelQuantization(
    Operation *op,
    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
    Type containerType) {
  // Sub-channel quantization only makes sense inside a tensor.
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use sub-channel quantization");

  // Unlike per-axis quantization, a rank is required: block sizes and scale
  // shapes cannot be validated without a concrete shape.
  if (!tensorType.hasRank())
    return op->emitError(
        "tensor containing the sub-channel quantized type must be ranked");

  // Pairs of (quantized dimension index, block size along that dimension).
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      uniformQuantizedSubChannelType.getBlockSizeInfo();
  auto shape = tensorType.getShape();

  // The dimension size of scale for an axis which is not specified as quantized
  // dimension should be 1.
  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    if (quantizedDimension >= tensorType.getRank())
      return op->emitError()
             << "quantized dimension " << quantizedDimension
             << " must be less than tensor rank " << tensorType.getRank();
    // Block size must evenly divide the (static) dimension it tiles.
    if (!tensorType.isDynamicDim(quantizedDimension) &&
        tensorType.getDimSize(quantizedDimension) % blockSize != 0)
      return op->emitError()
             << "tensor dimension size "
             << tensorType.getDimSize(quantizedDimension) << " at axis "
             << quantizedDimension
             << " must be divisible by the corresponding block size "
             << blockSize;
    // A dynamic quantized dimension implies a dynamic scale dimension;
    // otherwise the scale count is dimension size / block size.
    if (tensorType.isDynamicDim(quantizedDimension))
      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
    else
      expectedScaleShape[quantizedDimension] =
          tensorType.getDimSize(quantizedDimension) / blockSize;
  }

  // Block sizes must be greater than 0 and divide the corresponding dimension
  // size. While a block size b must be less than or equal to the corresponding
  // dimension size d, this constraint is implicitly enforced by requiring that
  // d % b == 0 when d != 0.
  //
  // However, a problem arises when d = 0. The divisibility constraint allows b
  // to be any value, potentially violating the requirement that b <= d.
  // Furthermore, if b is unspecified (implicitly equal to d), it violates the
  // constraint that b > 0.
  //
  // Therefore, we explicitly disallow the case where d = 0 to maintain
  // consistency and avoid these issues.
  if (llvm::is_contained(tensorType.getShape(), 0)) {
    return op->emitError() << "tensor dimension size of zero is not allowed "
                              "with sub-channel quantization";
  }

  // The scales tensor must have the same rank as the data tensor.
  auto scaleShape =
      uniformQuantizedSubChannelType.getScales().getType().getShape();
  if (scaleShape.size() != shape.size()) {
    return op->emitError() << "Rank of scales " << scaleShape.size()
                           << " must match "
                           << "the rank of the tensor " << shape.size();
  }

  // Every static axis of the scales tensor must match the expected shape
  // computed above (dynamic axes are skipped).
  for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
    if (expectedScaleShape[index] != ShapedType::kDynamic &&
        expectedScaleShape[index] != scaleShape[index])
      return op->emitError() << "dimension size " << scaleDim
                             << " of scales tensor at axis " << index
                             << " should match (tensor dimension at axis / "
                                "block sizes at axis) = "
                             << expectedScaleShape[index];
  }

  return success();
}
150
151// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
152//
153// - quantizedType
154// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
155// whether as a primitive type or in a tensor.
156//
157// - floatType
158// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
159// whether as a primitive type or in a tensor.
160//
161// - containerType
162// Type of original input or result.
163//
164LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
165 FloatType floatType, Type containerType) {
166 if (quantizedType.getExpressedType() != floatType)
167 return op->emitError(
168 "expressed type in quantized type expected to match float type");
169
170 // Verify integrity of per-axis quantization information, if present.
171 if (auto quantizedPerAxisType =
172 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
173 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
174 }
175
176 if (auto quantizedSubChannelType =
177 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
178 return verifySubChannelQuantization(op, quantizedSubChannelType,
179 containerType);
180 }
181
182 // At this point the type is UniformQuantizedType
183 return success();
184}
185
186} // namespace
187
188//===----------------------------------------------------------------------===//
189// Dialect
190//===----------------------------------------------------------------------===//
191
// Registers the dialect's operations from the TableGen-generated op list.
// NOTE(review): this listing was extracted from generated documentation and
// some interior lines appear to have been dropped by the extractor (the
// original line numbering skips here) — confirm against the upstream source
// before treating this body as complete.
void QuantDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
}
201
202//===----------------------------------------------------------------------===//
203// DequantizeCastOp
204//===----------------------------------------------------------------------===//
205
206LogicalResult DequantizeCastOp::verify() {
207 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
208 getInput().getType());
209}
210
211OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
212 // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
213 // with the value of x. Values x and y are guaranteed to be of the same type
214 // in this pattern.
215 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
216 if (!srcQcastOp)
217 return {};
218 assert(srcQcastOp.getInput().getType() == getType());
219 return srcQcastOp.getInput();
220}
221
222FloatType DequantizeCastOp::getFloatType() {
223 return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
224}
225
226QuantizedType DequantizeCastOp::getQuantizedType() {
227 return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
228}
229
230//===----------------------------------------------------------------------===//
231// QuantizeCastOp
232//===----------------------------------------------------------------------===//
233
234LogicalResult QuantizeCastOp::verify() {
235 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
236 getInput().getType());
237}
238
239OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
240 // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
241 // with the value of x if the casts invert each other. Contrary to the folding
242 // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
243 // x and y are not guaranteed to be of the same type here, as they may use
244 // different quantization parameters.
245 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
246 if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
247 return {};
248 return srcDcastOp.getInput();
249}
250
251FloatType QuantizeCastOp::getFloatType() {
252 return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
253}
254
255QuantizedType QuantizeCastOp::getQuantizedType() {
256 return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
257}
258
259//===----------------------------------------------------------------------===//
260// StorageCastOp
261//===----------------------------------------------------------------------===//
262
263LogicalResult StorageCastOp::verify() {
264 auto quantizedType = getQuantizedType();
265 auto integerType = getIntegerType();
266 if (quantizedType.getStorageType() != integerType)
267 return emitError(
268 "storage type in quantized type expected to match integer type");
269
270 // Verify integrity of per-axis quantization information, if available. While
271 // the quantization type may appear in the input or the result, their tensor
272 // shapes are guaranteed to be identical at this point.
273 if (auto quantizedPerAxisType =
274 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
275 return verifyPerAxisQuantization(*this, quantizedPerAxisType,
276 getInput().getType());
277 }
278
279 if (auto quantizedSunChannelType =
280 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
281 return verifySubChannelQuantization(*this, quantizedSunChannelType,
282 getInput().getType());
283 }
284
285 // At this point the type is UniformQuantizedType
286 return success();
287}
288
289OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
290 // Matches x -> quant.scast -> quant.scast -> y, replacing the second
291 // quant.scast with the value of x if the casts invert each other.
292 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
293 if (!srcScastOp || srcScastOp.getInput().getType() != getType())
294 return {};
295 return srcScastOp.getInput();
296}
297
298IntegerType StorageCastOp::getIntegerType() {
299 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
300 if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
301 return integerType;
302
303 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
304 return cast<IntegerType>(resultScalarType);
305}
306
307QuantizedType StorageCastOp::getQuantizedType() {
308 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
309 if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
310 return quantizedType;
311
312 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
313 return cast<QuantizedType>(resultScalarType);
314}
315
316} // namespace quant
317} // namespace mlir
318
319#define GET_OP_CLASSES
320#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
return success()
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Represents per-axis (also known as per-channel) quantization.
Definition QuantTypes.h:324
Represents sub-channel (also known as blockwise) quantization.
Definition QuantTypes.h:409
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not s...
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.