MLIR 23.0.0git
QuantOps.cpp
Go to the documentation of this file.
1//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "TypeDetail.h"
11
18
19#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
20
21namespace mlir {
22namespace quant {
23
24namespace {
25
26// Verify the integrity of per-axis quantization information, if present.
27//
28// - uniformQuantizedPerAxisType
29// A quantized type with per-axis quantization.
30//
31// - containerType
32// Original input or result type of the operation using the provided quantized
33// type. Used to ensure that the quantized type appears within a tensor and
34// that the tensor is compatible with per-axis quantization information.
35//
36LogicalResult verifyPerAxisQuantization(
37 Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
38 Type containerType) {
39 auto tensorType = dyn_cast<TensorType>(containerType);
40 if (!tensorType)
41 return op->emitError("scalar types may not use per-axis quantization");
42
43 if (!tensorType.hasRank())
44 return success();
45
46 int32_t quantizedDimension =
47 uniformQuantizedPerAxisType.getQuantizedDimension();
48 if ((int64_t)quantizedDimension >= tensorType.getRank())
49 return op->emitError("quantized dimension must be less than tensor rank");
50
51 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
52 if (quantizedDimensionSize != ShapedType::kDynamic &&
53 quantizedDimensionSize !=
54 (int64_t)uniformQuantizedPerAxisType.getScales().size())
55 return op->emitError(
56 "quantized dimension size does not match number of scales");
57
58 return success();
59}
60
61// Verifies that the sub-channel quantization parameters are consistent with
62// the given container type. The function checks the following:
63//
64// - The container type must be a ranked tensor type.
65// - Each quantized dimension must be less than the rank of the tensor.
66// - The size of each dimension at the quantized dimension must be divisible
67// by the corresponding block size.
68// - The scale dimension size at each axis index should match the tensor
69// dimension at the index divided by the corresponding block size.
70//
71// The `uniformQuantizedSubChannelType` argument provides the sub-channel
72// quantization parameters, and the `containerType` argument specifies the
73// type of the container holding the quantized data.
74//
75LogicalResult verifySubChannelQuantization(
76 Operation *op,
77 UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
78 Type containerType) {
79 auto tensorType = dyn_cast<TensorType>(containerType);
80 if (!tensorType)
81 return op->emitError("scalar types may not use sub-channel quantization");
82
83 if (!tensorType.hasRank())
84 return op->emitError(
85 "tensor containing the sub-channel quantized type must be ranked");
86
87 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
88 uniformQuantizedSubChannelType.getBlockSizeInfo();
89 auto shape = tensorType.getShape();
90
91 // The dimension size of scale for an axis which is not specified as quantized
92 // dimension should be 1.
93 SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
94 for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
95 if (quantizedDimension >= tensorType.getRank())
96 return op->emitError()
97 << "quantized dimension " << quantizedDimension
98 << " must be less than tensor rank " << tensorType.getRank();
99 if (!tensorType.isDynamicDim(quantizedDimension) &&
100 tensorType.getDimSize(quantizedDimension) % blockSize != 0)
101 return op->emitError()
102 << "tensor dimension size "
103 << tensorType.getDimSize(quantizedDimension) << " at axis "
104 << quantizedDimension
105 << " must be divisible by the corresponding block size "
106 << blockSize;
107 if (tensorType.isDynamicDim(quantizedDimension))
108 expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
109 else
110 expectedScaleShape[quantizedDimension] =
111 tensorType.getDimSize(quantizedDimension) / blockSize;
112 }
113
114 // Block sizes must be greater than 0 and divide the corresponding dimension
115 // size. While a block size b must be less than or equal to the corresponding
116 // dimension size d, this constraint is implicitly enforced by requiring that
117 // d % b == 0 when d != 0.
118 //
119 // However, a problem arises when d = 0. The divisibility constraint allows b
120 // to be any value, potentially violating the requirement that b <= d.
121 // Furthermore, if b is unspecified (implicitly equal to d), it violates the
122 // constraint that b > 0.
123 //
124 // Therefore, we explicitly disallow the case where d = 0 to maintain
125 // consistency and avoid these issues.
126 if (llvm::is_contained(tensorType.getShape(), 0)) {
127 return op->emitError() << "tensor dimension size of zero is not allowed "
128 "with sub-channel quantization";
129 }
130
131 auto scaleShape =
132 uniformQuantizedSubChannelType.getScales().getType().getShape();
133 if (scaleShape.size() != shape.size()) {
134 return op->emitError() << "Rank of scales " << scaleShape.size()
135 << " must match "
136 << "the rank of the tensor " << shape.size();
137 }
138
139 for (auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
140 if (expectedScaleShape[index] != ShapedType::kDynamic &&
141 expectedScaleShape[index] != scaleShape[index])
142 return op->emitError() << "dimension size " << scaleDim
143 << " of scales tensor at axis " << index
144 << " should match (tensor dimension at axis / "
145 "block sizes at axis) = "
146 << expectedScaleShape[index];
147 }
148
149 return success();
150}
151
152// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
153//
154// - quantizedType
155// Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
156// whether as a primitive type or in a tensor.
157//
158// - floatType
159// Float type used in the input ('quant.qcast') or result ('quant.dcast'),
160// whether as a primitive type or in a tensor.
161//
162// - containerType
163// Type of original input or result.
164//
165LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
166 FloatType floatType, Type containerType) {
167 if (quantizedType.getExpressedType() != floatType)
168 return op->emitError(
169 "expressed type in quantized type expected to match float type");
170
171 // Verify integrity of per-axis quantization information, if present.
172 if (auto quantizedPerAxisType =
173 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
174 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
175 }
176
177 if (auto quantizedSubChannelType =
178 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
179 return verifySubChannelQuantization(op, quantizedSubChannelType,
180 containerType);
181 }
182
183 // At this point the type is UniformQuantizedType
184 return success();
185}
186
187struct QuantInlinerInterface : public DialectInlinerInterface {
188 using DialectInlinerInterface::DialectInlinerInterface;
189 /// All quant dialect ops can be inlined.
190 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
191 return true;
192 }
193};
194
195} // namespace
196
197//===----------------------------------------------------------------------===//
198// Dialect
199//===----------------------------------------------------------------------===//
200
201void QuantDialect::initialize() {
204 QuantileType>();
205 addOperations<
206#define GET_OP_LIST
207#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
208 >();
210 addInterfaces<QuantInlinerInterface>();
211}
212
213//===----------------------------------------------------------------------===//
214// DequantizeCastOp
215//===----------------------------------------------------------------------===//
216
217LogicalResult DequantizeCastOp::verify() {
218 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
219 getInput().getType());
220}
221
222OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
223 // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
224 // with the value of x. Values x and y are guaranteed to be of the same type
225 // in this pattern.
226 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
227 if (!srcQcastOp)
228 return {};
229 assert(srcQcastOp.getInput().getType() == getType());
230 return srcQcastOp.getInput();
231}
232
233FloatType DequantizeCastOp::getFloatType() {
234 return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
235}
236
237QuantizedType DequantizeCastOp::getQuantizedType() {
238 return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
239}
240
241//===----------------------------------------------------------------------===//
242// QuantizeCastOp
243//===----------------------------------------------------------------------===//
244
245LogicalResult QuantizeCastOp::verify() {
246 return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
247 getInput().getType());
248}
249
250OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
251 // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
252 // with the value of x if the casts invert each other. Contrary to the folding
253 // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
254 // x and y are not guaranteed to be of the same type here, as they may use
255 // different quantization parameters.
256 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
257 if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
258 return {};
259 return srcDcastOp.getInput();
260}
261
262FloatType QuantizeCastOp::getFloatType() {
263 return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
264}
265
266QuantizedType QuantizeCastOp::getQuantizedType() {
267 return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
268}
269
270//===----------------------------------------------------------------------===//
271// StorageCastOp
272//===----------------------------------------------------------------------===//
273
274LogicalResult StorageCastOp::verify() {
275 auto quantizedType = getQuantizedType();
276 auto integerType = getIntegerType();
277 if (quantizedType.getStorageType() != integerType)
278 return emitError(
279 "storage type in quantized type expected to match integer type");
280
281 // Verify integrity of per-axis quantization information, if available. While
282 // the quantization type may appear in the input or the result, their tensor
283 // shapes are guaranteed to be identical at this point.
284 if (auto quantizedPerAxisType =
285 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
286 return verifyPerAxisQuantization(*this, quantizedPerAxisType,
287 getInput().getType());
288 }
289
290 if (auto quantizedSunChannelType =
291 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
292 return verifySubChannelQuantization(*this, quantizedSunChannelType,
293 getInput().getType());
294 }
295
296 // At this point the type is UniformQuantizedType
297 return success();
298}
299
300OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
301 // Matches x -> quant.scast -> quant.scast -> y, replacing the second
302 // quant.scast with the value of x if the casts invert each other.
303 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
304 if (!srcScastOp || srcScastOp.getInput().getType() != getType())
305 return {};
306 return srcScastOp.getInput();
307}
308
309IntegerType StorageCastOp::getIntegerType() {
310 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
311 if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
312 return integerType;
313
314 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
315 return cast<IntegerType>(resultScalarType);
316}
317
318QuantizedType StorageCastOp::getQuantizedType() {
319 auto inputScalarType = getElementTypeOrSelf(getInput().getType());
320 if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
321 return quantizedType;
322
323 auto resultScalarType = getElementTypeOrSelf(getResult().getType());
324 return cast<QuantizedType>(resultScalarType);
325}
326
327} // namespace quant
328} // namespace mlir
329
330#define GET_OP_CLASSES
331#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
return success()
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
A quantized type that maps storage to/from expressed types in an unspecified way.
Definition QuantTypes.h:204
A quantized type that infers its range from given min/max values.
Definition QuantTypes.h:525
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:51
Represents per-axis (also known as per-channel) quantization.
Definition QuantTypes.h:325
Represents sub-channel (also known as blockwise) quantization.
Definition QuantTypes.h:410
Represents a family of uniform, quantized types.
Definition QuantTypes.h:265
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width, or null if the bit width is not supported.
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.