MLIR 22.0.0git
FakeQuantSupport.cpp
Go to the documentation of this file.
1//===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
12using namespace mlir;
13using namespace mlir::quant;
14
15static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16 bool isSigned, MLIRContext *ctx,
17 Type &storageType, int64_t &qmin,
18 int64_t &qmax) {
19 // Hard-coded type mapping from TFLite.
20 if (numBits <= 8) {
21 storageType = IntegerType::get(ctx, 8);
22 if (isSigned) {
23 qmin = -128;
24 qmax = 127;
25 } else {
26 qmin = 0;
27 qmax = 255;
28 }
29 } else if (numBits <= 16) {
30 storageType = IntegerType::get(ctx, 16);
31 if (isSigned) {
32 qmin = -32768;
33 qmax = 32767;
34 } else {
35 qmin = 0;
36 qmax = 65535;
37 }
38 } else if (numBits <= 32) {
39 storageType = IntegerType::get(ctx, 32);
40 if (isSigned) {
41 qmin = std::numeric_limits<int32_t>::min();
42 qmax = std::numeric_limits<int32_t>::max();
43 } else {
44 qmin = std::numeric_limits<uint32_t>::min();
45 qmax = std::numeric_limits<uint32_t>::max();
46 }
47 } else {
48 return true;
49 }
50
51 // Handle narrowRange.
52 if (narrowRange) {
53 qmin += 1;
54 }
55 return false;
56}
57
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence, some values that are supposed to be in the original
// [rmin, rmax] range will be outside the shifted range and be clamped during
// quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Outputs:
//   scale           - (rmax - rmin) / (qmax - qmin). It is not validated here,
//                     so it can be negative or non-finite for reversed or
//                     non-finite inputs; callers pass it to getChecked.
//   nudgedZeroPoint - always an integer in [qmin, qmax]. A zero point that
//                     computes to NaN (e.g. from a NaN rmin/rmax, or an
//                     infinite rmax giving inf/inf) is clamped to qmin rather
//                     than converted to an integer, which would be undefined
//                     behavior.
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = qmin;
  const double qmaxDouble = qmax;
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer. The lower bound test is written
  // as `!(x >= qmin)` so that a NaN zero point (for which every comparison is
  // false) is clamped to qmin instead of reaching the double -> int64_t
  // conversion below, whose behavior is undefined for NaN.
  if (!(zeroPointDouble >= qminDouble)) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106
108mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109 double rmax, bool narrowRange,
110 Type expressedType, bool isSigned) {
111 MLIRContext *ctx = expressedType.getContext();
112 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113 Type storageType;
114 int64_t qmin;
115 int64_t qmax;
116 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117 qmin, qmax)) {
118 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119 nullptr);
120 }
121
122 // Special case where min/max is close enough. The tensor contents are all
123 // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124 // points and dequantized to 0.0.
125 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
127 loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
128 }
129
130 double scale;
131 int64_t nudgedZeroPoint;
132 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133
134 return UniformQuantizedType::getChecked(loc, flags, storageType,
135 expressedType, scale, nudgedZeroPoint,
136 qmin, qmax);
137}
138
140 Location loc, unsigned numBits, int32_t quantizedDimension,
141 ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142 Type expressedType, bool isSigned) {
143 size_t axisSize = rmins.size();
144 if (axisSize != rmaxs.size()) {
145 return (emitError(loc, "mismatched per-axis min and max size: ")
146 << axisSize << " vs. " << rmaxs.size(),
147 nullptr);
148 }
149
150 MLIRContext *ctx = expressedType.getContext();
151 Type storageType;
152 int64_t qmin;
153 int64_t qmax;
154 if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155 qmin, qmax)) {
156 return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157 nullptr);
158 }
159
161 SmallVector<int64_t, 4> zeroPoints;
162 scales.reserve(axisSize);
163 zeroPoints.reserve(axisSize);
164 for (size_t axis = 0; axis != axisSize; ++axis) {
165 double rmin = rmins[axis];
166 double rmax = rmaxs[axis];
167 if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168 scales.push_back(1.0);
169 zeroPoints.push_back(qmin);
170 continue;
171 }
172
173 double scale;
174 int64_t nudgedZeroPoint;
175 getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176 scales.push_back(scale);
177 zeroPoints.push_back(nudgedZeroPoint);
178 }
179
180 unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
182 loc, flags, storageType, expressedType, scales, zeroPoints,
183 quantizedDimension, qmin, qmax);
184}
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax, double &scale, int64_t &nudgedZeroPoint)
static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, MLIRContext *ctx, Type &storageType, int64_t &qmin, int64_t &qmax)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
Represents per-axis (also known as per-channel) quantization.
Definition QuantTypes.h:324
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, Type expressedType, bool isSigned=false)
Converts per-layer FakeQuant attributes to the corresponding type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.