MLIR  20.0.0git
FakeQuantSupport.cpp
Go to the documentation of this file.
1 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 using namespace mlir::quant;
14 
// Picks the integer storage container and its [qmin, qmax] range for a
// FakeQuant op with `numBits` bits, writing them to `storageType`, `qmin` and
// `qmax`:
//   numBits <= 8  -> i8 storage,  [-128, 127] if isSigned, else [0, 255]
//   numBits <= 16 -> i16 storage, [-32768, 32767] if isSigned, else [0, 65535]
//   numBits <= 32 -> i32 storage (range assigned in the 32-bit branch below)
// If `narrowRange` is set, qmin is raised by one (e.g. [-127, 127] for i8).
// Returns true on failure (numBits > 32, in which case no output is written)
// and false on success.
// NOTE(review): only the container is derived from numBits; a narrower width
// (e.g. numBits = 4, or even 0) still yields the full 8-bit range. Confirm
// this TFLite-style behavior is what callers expect.
15 static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
16  bool isSigned, MLIRContext *ctx,
17  Type &storageType, int64_t &qmin,
18  int64_t &qmax) {
19  // Hard-coded type mapping from TFLite.
20  if (numBits <= 8) {
21  storageType = IntegerType::get(ctx, 8);
22  if (isSigned) {
23  qmin = -128;
24  qmax = 127;
25  } else {
26  qmin = 0;
27  qmax = 255;
28  }
29  } else if (numBits <= 16) {
30  storageType = IntegerType::get(ctx, 16);
31  if (isSigned) {
32  qmin = -32768;
33  qmax = 32767;
34  } else {
35  qmin = 0;
36  qmax = 65535;
37  }
38  } else if (numBits <= 32) {
39  storageType = IntegerType::get(ctx, 32);
40  if (isSigned) {
  // NOTE(review): the statements of this branch and of the else branch
  // (source lines 41-42 and 44-45) are not shown in this listing; presumably
  // they assign the signed / unsigned 32-bit qmin/qmax. Confirm.
43  } else {
46  }
47  } else {
  // Unsupported width: signal failure to the caller without touching the
  // output parameters.
48  return true;
49  }
50 
51  // Handle narrowRange.
  // Drop the most negative (or zero, for unsigned) code so the usable range
  // loses one value at the bottom.
52  if (narrowRange) {
53  qmin += 1;
54  }
55  return false;
56 }
57 
// Derives an affine (scale, zero point) pair mapping the real range
// [rmin, rmax] onto the integer storage range [qmin, qmax].
//
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width (rmax - rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence, some values that are supposed to be in the original
// [rmin, rmax] range will be outside the shifted range and will be clamped
// during quantization.
// TODO: we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
//
// Parameters:
//   qmin, qmax      - inclusive integer storage range (expected qmin < qmax).
//   rmin, rmax      - real-valued range of the tensor.
//   scale           - [out] (rmax - rmin) / (qmax - qmin).
//   nudgedZeroPoint - [out] integral zero point, always within [qmin, qmax].
//
// The callers in this file filter out |rmax - rmin| < epsilon beforehand, so
// the scale is nonzero for finite inputs. If rmin or rmax is NaN or infinite,
// the zero point can come out as NaN. In that case it falls back to qmin so
// that the double -> int64_t conversion stays well defined (converting NaN to
// an integer is undefined behavior). The non-finite `scale` is returned
// unchanged; presumably the caller's checked type construction rejects it
// (confirm).
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                       double rmax, double &scale,
                                       int64_t &nudgedZeroPoint) {
  // Determine the scale.
  const double qminDouble = static_cast<double>(qmin);
  const double qmaxDouble = static_cast<double>(qmax);
  scale = (rmax - rmin) / (qmaxDouble - qminDouble);

  // Zero point computation.
  // In float, solve the affine equation for any known pair
  // (real value, corresponding quantized value), of which, two such pairs
  // are known: (rmin, qmin), (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair will be
  // roughly machine_epsilon * (sum of absolute values of terms).
  // Use the variant that adds the smaller error.
  const double zeroPointFromMin = qminDouble - rmin / scale;
  const double zeroPointFromMinError =
      std::abs(qminDouble) + std::abs(rmin / scale);
  const double zeroPointFromMax = qmaxDouble - rmax / scale;
  const double zeroPointFromMaxError =
      std::abs(qmaxDouble) + std::abs(rmax / scale);

  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                     ? zeroPointFromMin
                                     : zeroPointFromMax;

  // Now nudge the zero point to be an integer in [qmin, qmax].
  if (std::isnan(zeroPointDouble)) {
    // Non-finite rmin/rmax: NaN compares false against both bounds and cannot
    // be converted to int64_t, so pick an in-range placeholder.
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble < qminDouble) {
    nudgedZeroPoint = qmin;
  } else if (zeroPointDouble > qmaxDouble) {
    nudgedZeroPoint = qmax;
  } else {
    // Rounds halfway cases away from zero.
    nudgedZeroPoint = static_cast<int64_t>(std::round(zeroPointDouble));
  }

  // By construction, the nudged zero point should always be in range.
  assert(nudgedZeroPoint >= qmin);
  assert(nudgedZeroPoint <= qmax);
}
106 
108 mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
109  double rmax, bool narrowRange,
110  Type expressedType, bool isSigned) {
  // Converts per-layer FakeQuant attributes (numBits, [rmin, rmax],
  // narrowRange, isSigned) into a UniformQuantizedType over `expressedType`.
  // The declaration is `UniformQuantizedType fakeQuantAttrsToType(...,
  // bool isSigned = false)`. An unsupported bit width emits an error at `loc`
  // and returns a null type.
  // NOTE(review): rmin > rmax is not rejected here; it yields a negative
  // scale, presumably caught by the checked type construction. Confirm.
111  MLIRContext *ctx = expressedType.getContext();
112  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
113  Type storageType;
114  int64_t qmin;
115  int64_t qmax;
116  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
117  qmin, qmax)) {
  // getDefaultStorageParams returns true for an unsupported bit width. The
  // comma expression emits the diagnostic first, then yields nullptr as the
  // (null) return value.
118  return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
119  nullptr);
120  }
121 
122  // Special case where min/max is close enough. The tensor contents are all
123  // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
124  // points and dequantized to 0.0.
  // NOTE(review): the threshold is the absolute DBL_EPSILON (~2.2e-16), not
  // relative to the magnitude of rmin/rmax.
  // NOTE(review): the line opening the return call in this branch (source
  // line 126) is not shown in this listing. Following the getChecked
  // parameter order, the visible arguments are scale = 1.0, zeroPoint = qmin,
  // storageTypeMin = qmin, storageTypeMax = qmax.
125  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
127  loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
128  }
129 
130  double scale;
131  int64_t nudgedZeroPoint;
132  getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
133 
  // getChecked validates the parameters and reports problems at `loc`.
134  return UniformQuantizedType::getChecked(loc, flags, storageType,
135  expressedType, scale, nudgedZeroPoint,
136  qmin, qmax);
137 }
138 
140  Location loc, unsigned numBits, int32_t quantizedDimension,
141  ArrayRef<double> rmins, ArrayRef<double> rmaxs, bool narrowRange,
142  Type expressedType, bool isSigned) {
  // Per-axis FakeQuant conversion. For each index `axis`, it computes one
  // (scale, zero point) pair from rmins[axis] / rmaxs[axis]. It then builds a
  // per-axis quantized type along `quantizedDimension`, using a single
  // storage type and [qmin, qmax] shared by all axes. Errors (mismatched
  // sizes, unsupported bit width) are emitted at `loc` and yield a null type.
  // NOTE(review): the return type / name line (source line 139) is not shown
  // in this listing; presumably
  // `UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(`.
  // Confirm.
  // NOTE(review): empty rmins/rmaxs are not rejected here; an empty
  // scales/zeroPoints list is passed on as-is.
143  size_t axisSize = rmins.size();
144  if (axisSize != rmaxs.size()) {
145  return (emitError(loc, "mismatched per-axis min and max size: ")
146  << axisSize << " vs. " << rmaxs.size(),
147  nullptr);
148  }
149 
150  MLIRContext *ctx = expressedType.getContext();
151  Type storageType;
152  int64_t qmin;
153  int64_t qmax;
154  if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
155  qmin, qmax)) {
156  return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
157  nullptr);
158  }
159 
160  SmallVector<double, 4> scales;
161  SmallVector<int64_t, 4> zeroPoints;
162  scales.reserve(axisSize);
163  zeroPoints.reserve(axisSize);
164  for (size_t axis = 0; axis != axisSize; ++axis) {
165  double rmin = rmins[axis];
166  double rmax = rmaxs[axis];
  // A (numerically) zero-width range is treated as all zeros, the same way as
  // in the per-layer case: scale 1.0 with the zero point at qmin.
167  if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
168  scales.push_back(1.0);
169  zeroPoints.push_back(qmin);
170  continue;
171  }
172 
173  double scale;
174  int64_t nudgedZeroPoint;
175  getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
176  scales.push_back(scale);
177  zeroPoints.push_back(nudgedZeroPoint);
178  }
179 
  // NOTE(review): the line opening the return call (source line 181) is not
  // shown in this listing. The visible arguments match the parameter order of
  // UniformQuantizedPerAxisType::getChecked (flags, storageType,
  // expressedType, scales, zeroPoints, quantizedDimension, storageTypeMin,
  // storageTypeMax).
180  unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
182  loc, flags, storageType, expressedType, scales, zeroPoints,
183  quantizedDimension, qmin, qmax);
184 }
static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax, double &scale, int64_t &nudgedZeroPoint)
static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, MLIRContext *ctx, Type &storageType, int64_t &qmin, int64_t &qmax)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
Represents per-axis (also known as per-channel quantization).
Definition: QuantTypes.h:321
static UniformQuantizedPerAxisType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:348
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:261
static UniformQuantizedType getChecked(function_ref< InFlightDiagnostic()> emitError, unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all specified parameters checked.
Definition: QuantTypes.cpp:292
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, Type expressedType, bool isSigned=false)
Converts per-layer FakeQuant attributes to the corresponding type.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.