MLIR 23.0.0git
QuantOps.cpp
Go to the documentation of this file.
1//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11
18
19#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
20
21namespace mlir {
22namespace quant {
23
24namespace {
25
26// Verify the integrity of per-axis quantization information, if present.
27//
28// - uniformQuantizedPerAxisType
29// A quantized type with per-axis quantization.
30//
31// - containerType
32// Original input or result type of the operation using the provided quantized
33// type. Used to ensure that the quantized type appears within a tensor and
34// that the tensor is compatible with per-axis quantization information.
35//
36LogicalResult verifyPerAxisQuantization(
37 Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
38 Type containerType) {
39 auto tensorType = dyn_cast<TensorType>(containerType);
40 if (!tensorType)
41 return op->emitError("scalar types may not use per-axis quantization");
42
43 if (!tensorType.hasRank())
44 return success();
45
46 int32_t quantizedDimension =
47 uniformQuantizedPerAxisType.getQuantizedDimension();
48 if ((int64_t)quantizedDimension >= tensorType.getRank())
49 return op->emitError("quantized dimension must be less than tensor rank");
50
51 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
52 if (quantizedDimensionSize != ShapedType::kDynamic &&
53 quantizedDimensionSize !=
54 (int64_t)uniformQuantizedPerAxisType.getScales().size())
55 return op->emitError(
56 "quantized dimension size does not match number of scales");
57
58 return success();
59}
60
61// Verifies that the sub-channel quantization parameters are consistent with
62// the given container type. The function checks the following:
63//
64// - The container type must be a ranked tensor type.
65// - Each quantized dimension must be less than the rank of the tensor.
66// - The size of each dimension at the quantized dimension must be divisible
67// by the corresponding block size.
68// - The scale dimension size at each axis index should match the tensor
69// dimension at the index divided by the corresponding block size.
70//
71// The `uniformQuantizedSubChannelType` argument provides the sub-channel
72// quantization parameters, and the `containerType` argument specifies the
73// type of the container holding the quantized data.
74//
75LogicalResult verifySubChannelQuantization(
76 Operation *op,
77 UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
78 Type containerType) {
79 auto tensorType = dyn_cast<TensorType>(containerType);
80 if (!tensorType)
81 return op->emitError("scalar types may not use sub-channel quantization");
82
83 if (!tensorType.hasRank())
84 return op->emitError(
85 "tensor containing the sub-channel quantized type must be ranked");
86
87 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
88 uniformQuantizedSubChannelType.getBlockSizeInfo();
89 auto shape = tensorType.getShape();
90
91 // The dimension size of scale for an axis which is not specified as quantized
92 // dimension should be 1.
93 SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
94 for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
95 if (quantizedDimension >= tensorType.getRank())
96 return op->emitError()
97 << "quantized dimension " << quantizedDimension
98 << " must be less than tensor rank " << tensorType.getRank();
99 if (!tensorType.isDynamicDim(quantizedDimension) &&
100 tensorType.getDimSize(quantizedDimension) % blockSize != 0)
101 return op->emitError()
102 << "tensor dimension size "
103 << tensorType.getDimSize(quantizedDimension) << " at axis "
104 << quantizedDimension
105 << " must be divisible by the corresponding block size "
106 << blockSize;
107 if (tensorType.isDynamicDim(quantizedDimension))
108 expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
109 else
110 expectedScaleShape[quantizedDimension] =
111 tensorType.getDimSize(quantizedDimension) / blockSize;
112 }
113
114 // Block sizes must be greater than 0 and divide the corresponding dimension
115 // size. While a block size b must be less than or equal to the corresponding
116 // dimension size d, this constraint is implicitly enforced by requiring that
117 // d % b == 0 when d != 0.
118 //
119 // However, a problem arises when d = 0. The divisibility constraint allows b
120 // to be any value, potentially violating the requirement that b <= d.
121 // Furthermore, if b is unspecified (implicitly equal to d), it violates the
122 // constraint that b > 0.
123 //
124 // Therefore, we explicitly disallow the case where d = 0 to maintain
125 // consistency and avoid these issues.
126 if (llvm::is_contained(tensorType.getShape(), 0)) {
127 return op->emitError() << "tensor dimension size of zero is not allowed "
128 "with sub-channel quantization";
129 }
130
131 auto scaleShape =
132 uniformQuantizedSubChannelType.getScales().getType().getShape();
133 if (scaleShape.size() != shape.size()) {
134 return op->emitError() << "Rank of scales " << scaleShape.size()
135 << " must match "
136 << "the rank of the tensor " << shape.size();
137 }
138
139 for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
140 if (expectedScaleShape[index] != ShapedType::kDynamic &&
141 expectedScaleShape[index] != scaleShape[index])
142 return op->emitError() << "dimension size " << scaleDim
143 << " of scales tensor at axis " << index
144 << " should match (tensor dimension at axis / "
145 "block sizes at axis) = "
146 << expectedScaleShape[index];
147 }
148
149 return success();
150}
151
152// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
153//
154// - quantizedType
155// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
156// whether as a primitive type or in a tensor.
157//
158// - floatType
159// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
160// whether as a primitive type or in a tensor.
161//
162// - containerType
163// Type of original input or result.
164//
165LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
166 FloatType floatType, Type containerType) {
167 if (quantizedType.getExpressedType() != floatType)
168 return op->emitError(
169 "expressed type in quantized type expected to match float type");
170
171 // Verify integrity of per-axis quantization information, if present.
172 if (auto quantizedPerAxisType =
173 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
174 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
175 }
176
177 if (auto quantizedSubChannelType =
178 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
179 return verifySubChannelQuantization(op, quantizedSubChannelType,
180 containerType);
181 }
182
183 // At this point the type is UniformQuantizedType
184 return success();
185}
186
/// Inliner interface registration for the quant dialect.
struct QuantInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  /// All quant dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
194
195} // namespace
196
197//===----------------------------------------------------------------------===//
198// Dialect
199//===----------------------------------------------------------------------===//
200
void QuantDialect::initialize() {
  // Register all operations generated from the QuantOps.td definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  // NOTE(review): the rendered source elides a few lines in this function
  // (doxygen line numbers 202-203 and 208 are missing) — likely type and/or
  // bytecode registration; confirm against the full file before editing.
  addInterfaces<QuantInlinerInterface>();
}
211
212//===----------------------------------------------------------------------===//
213// DequantizeCastOp
214//===----------------------------------------------------------------------===//
215
216LogicalResult DequantizeCastOp::verify() {
217 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
218 getInput().getType());
219}
220
221OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
222 // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
223 // with the value of x. Values x and y are guaranteed to be of the same type
224 // in this pattern.
225 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
226 if (!srcQcastOp)
227 return {};
228 assert(srcQcastOp.getInput().getType() == getType());
229 return srcQcastOp.getInput();
230}
231
232FloatType DequantizeCastOp::getFloatType() {
233 return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
234}
235
236QuantizedType DequantizeCastOp::getQuantizedType() {
237 return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
238}
239
240//===----------------------------------------------------------------------===//
241// QuantizeCastOp
242//===----------------------------------------------------------------------===//
243
244LogicalResult QuantizeCastOp::verify() {
245 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
246 getInput().getType());
247}
248
249OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
250 // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
251 // with the value of x if the casts invert each other. Contrary to the folding
252 // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
253 // x and y are not guaranteed to be of the same type here, as they may use
254 // different quantization parameters.
255 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
256 if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
257 return {};
258 return srcDcastOp.getInput();
259}
260
261FloatType QuantizeCastOp::getFloatType() {
262 return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
263}
264
265QuantizedType QuantizeCastOp::getQuantizedType() {
266 return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
267}
268
269//===----------------------------------------------------------------------===//
270// StorageCastOp
271//===----------------------------------------------------------------------===//
272
273LogicalResult StorageCastOp::verify() {
274 auto quantizedType = getQuantizedType();
275 auto integerType = getIntegerType();
276 if (quantizedType.getStorageType() != integerType)
277 return emitError(
278 "storage type in quantized type expected to match integer type");
279
280 // Verify integrity of per-axis quantization information, if available. While
281 // the quantization type may appear in the input or the result, their tensor
282 // shapes are guaranteed to be identical at this point.
283 if (auto quantizedPerAxisType =
284 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
285 return verifyPerAxisQuantization(*this, quantizedPerAxisType,
286 getInput().getType());
287 }
288
289 if (auto quantizedSunChannelType =
290 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
291 return verifySubChannelQuantization(*this, quantizedSunChannelType,
292 getInput().getType());
293 }
294
295 // At this point the type is UniformQuantizedType
296 return success();
297}
298
299OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
300 // Matches x -> quant.scast -> quant.scast -> y, replacing the second
301 // quant.scast with the value of x if the casts invert each other.
302 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
303 if (!srcScastOp || srcScastOp.getInput().getType() != getType())
304 return {};
305 return srcScastOp.getInput();
306}
307
308IntegerType StorageCastOp::getIntegerType() {
309 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
310 if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
311 return integerType;
312
313 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
314 return cast<IntegerType>(resultScalarType);
315}
316
317QuantizedType StorageCastOp::getQuantizedType() {
318 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
319 if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
320 return quantizedType;
321
322 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
323 return cast<QuantizedType>(resultScalarType);
324}
325
326} // namespace quant
327} // namespace mlir
328
329#define GET_OP_CLASSES
330#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:203
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:524
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not s...
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.