MLIR 22.0.0git
UniformSupport.h
Go to the documentation of this file.
1//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
10#define MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
11
12#include <utility>
13
16#include "mlir/IR/Types.h"
17#include "llvm/ADT/APFloat.h"
18#include "llvm/ADT/APInt.h"
19#include "llvm/ADT/APSInt.h"
20
21namespace mlir {
22namespace quant {
23
24/// Performs type conversion from an arbitrary input type to a type
25/// that is expressed by a QuantizedType.
26///
27/// This handles cases where the inputType is a supported primitive type
28/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
29/// elemental type.
30///
31/// Since conversion often involves introspecting some attributes of the
32/// input type in order to determine how to represent it, this is a two step
33/// process.
35 /// Creates a converter for the given input type.
37
38 /// Converts the inputType to be based on the given elemental type,
39 /// returning the new type (or nullptr, after emitting an error, on failure).
40 Type convert(QuantizedType elementalType) const;
41
42 /// Whether the conversion is legal.
43 explicit operator bool() const { return (bool)expressedType; }
44
45 /// The input type that is being converted from.
46 /// This may be an elemental or composite type.
48
49 /// Supported, elemental expressed type (i.e. f32).
50 /// Will be nullptr if conversion is not supported.
52};
53
54/// Reference implementation of converting between real numbers and values
55/// represented by a UniformQuantizedType.
56/// Note that this is not expected to be speedy and may be superseded eventually
57/// by a more optimal implementation.
58/// Also, the interface assumes that quantization is done per-layer and will
59/// need to be wider for various per-channel schemes. As such, this is a
60/// placeholder.
62public:
65 uniformType.getScale(),
66 static_cast<double>(uniformType.getZeroPoint()),
67 static_cast<double>(uniformType.getStorageTypeMin()),
68 static_cast<double>(uniformType.getStorageTypeMax()),
69 uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) {
70 assert(isa<FloatType>(uniformType.getExpressedType()));
71 assert(uniformType.getStorageType().isSignlessInteger());
72 }
73
74 UniformQuantizedValueConverter(double scale, double zeroPoint,
75 double clampMin, double clampMax,
76 uint32_t storageBitWidth, bool isSigned)
77 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
78 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
79 clampMinDouble(clampMin), clampMaxDouble(clampMax),
80 storageBitWidth(storageBitWidth), isSigned(isSigned),
81 roundMode(APFloat::rmNearestTiesToAway) {}
82
83 UniformQuantizedValueConverter(double scale, double zeroPoint,
84 const APFloat &clampMin,
85 const APFloat &clampMax,
86 uint32_t storageBitWidth, bool isSigned)
87 : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
88 clampMax(clampMax), scaleDouble(scale), zeroPointDouble(zeroPoint),
89 clampMinDouble(clampMin.convertToDouble()),
90 clampMaxDouble(clampMax.convertToDouble()),
91 storageBitWidth(storageBitWidth), isSigned(isSigned),
92 roundMode(APFloat::rmNearestTiesToAway) {}
93
94 virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
95 // This function is a performance critical code path in quantization
96 // since it runs for each single float parameter value.
97
98 // Specialize f32->u8/i8 case to optimize performance.
99 if (&expressedValue.getSemantics() == &APFloat::IEEEsingle() &&
100 storageBitWidth == 8 &&
101 roundMode == llvm::APFloatBase::rmNearestTiesToAway) {
102 return quantizeF32ToInt8(expressedValue);
103 }
104
105 bool lossy;
106 expressedValue.convert(scale.getSemantics(), roundMode, &lossy);
107 // fixedpoint = clamp(clampMin, clampMax, (
108 // roundHalfToEven(expressed / scale) + zeroPoint))
109 APFloat scaled = (expressedValue / scale);
110 scaled.roundToIntegral(roundMode);
111 scaled.add(zeroPoint, roundMode);
112 APFloat fixedpoint = llvm::minimum(scaled, clampMax);
113 fixedpoint = llvm::maximum(fixedpoint, clampMin);
114
115 llvm::APSInt result(storageBitWidth, !isSigned);
116 fixedpoint.convertToInteger(result, roundMode, &lossy);
117
118 return std::move(result);
119 }
120
121 int64_t quantizeFloatToInt64(APFloat expressedValue) const {
122 APInt qValue = quantizeFloatToInt(std::move(expressedValue));
123 return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
124 }
125
127
128private:
129 // An optimized implementation to quantize f32 to i8/u8 with C++ native
130 // arithmetic.
131 virtual APInt quantizeF32ToInt8(APFloat expressedValue) const {
132 assert(&expressedValue.getSemantics() == &APFloat::IEEEsingle());
133 assert(storageBitWidth == 8);
134 assert(roundMode == llvm::APFloatBase::rmNearestTiesToAway);
135
136 const float realValue = expressedValue.convertToFloat();
137
138 const double scaled = realValue / scaleDouble + zeroPointDouble;
139 // Round to nearest integer with halfway cases rounded away from zero.
140 const double scaledRounded = std::round(scaled);
141 const double clamped =
142 std::min(std::max(scaledRounded, clampMinDouble), clampMaxDouble);
143
144 uint64_t signlessResult;
145 if (isSigned) {
146 int64_t clampedInt = static_cast<int8_t>(clamped);
147 memcpy(&signlessResult, &clampedInt, sizeof(clampedInt));
148 } else {
149 signlessResult = static_cast<uint8_t>(clamped);
150 }
151 return APInt(storageBitWidth, signlessResult);
152 }
153
154 // Keep both APFloat and double versions of the quantization parameters
155 // around since they will be used in generic and specialized arithmetic,
156 // respectively.
157 const APFloat scale;
158 const APFloat zeroPoint;
159 const APFloat clampMin;
160 const APFloat clampMax;
161
162 const double scaleDouble;
163 const double zeroPointDouble;
164 const double clampMinDouble;
165 const double clampMaxDouble;
166
167 const uint32_t storageBitWidth;
168 const bool isSigned;
169 const llvm::APFloat::roundingMode roundMode;
170};
171
172/// A utility class to quantize an attribute by the per-axis quantization
173/// parameters. The size of the quantization dim in the converted elements
174/// attribute should match the size of scales/zeroPoints vectors in the
175/// quantization parameters.
177public:
179 UniformQuantizedPerAxisType uniformType)
180 : scales(uniformType.getScales()),
181 zeroPoints(uniformType.getZeroPoints()),
182 clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
183 clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
184 storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
185 isSigned(uniformType.isSigned()),
186 quantizationDim(uniformType.getQuantizedDimension()) {
187 assert(isa<FloatType>(uniformType.getExpressedType()));
188 assert(uniformType.getStorageType().isSignlessInteger());
189 assert(scales.size() == zeroPoints.size());
190 }
191
192 /// Quantize an Attribute by the quantization parameters. Return nullptr if
193 /// the conversion fails or the input array isn't an ElementsAttr.
194 ElementsAttr convert(Attribute realValue);
195
196private:
197 /// Quantize a DenseFPElementsAttr by the quantization parameters.
199
200 /// Get a uniform converter for the index-th chunk along the quantizationDim.
201 /// All the elements in this chunk is quantized by the returned converter.
202 UniformQuantizedValueConverter getPerChunkConverter(int index) const {
203 UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
204 clampMin, clampMax,
205 storageBitWidth, isSigned);
206 return converter;
207 }
208
209 const ArrayRef<double> scales;
210 const ArrayRef<int64_t> zeroPoints;
211 const APFloat clampMin;
212 const APFloat clampMax;
213 const uint32_t storageBitWidth;
214 const bool isSigned;
215 int32_t quantizationDim;
216};
217
218} // namespace quant
219} // namespace mlir
220
221#endif // MLIR_DIALECT_QUANT_UTILS_UNIFORMSUPPORT_H_
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:2596
Attributes are known-constant values of operations.
Definition Attributes.h:25
An attribute that represents a reference to a dense vector or tensor object.
An attribute that represents a reference to a dense float vector or tensor object.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
ElementsAttr convert(Attribute realValue)
Quantize an Attribute by the quantization parameters.
UniformQuantizedPerAxisValueConverter(UniformQuantizedPerAxisType uniformType)
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
Reference implementation of converting between real numbers and values represented by a UniformQuanti...
UniformQuantizedValueConverter(double scale, double zeroPoint, const APFloat &clampMin, const APFloat &clampMax, uint32_t storageBitWidth, bool isSigned)
int64_t quantizeFloatToInt64(APFloat expressedValue) const
virtual APInt quantizeFloatToInt(APFloat expressedValue) const
UniformQuantizedValueConverter(double scale, double zeroPoint, double clampMin, double clampMax, uint32_t storageBitWidth, bool isSigned)
UniformQuantizedValueConverter(UniformQuantizedType uniformType)
Include the generated interface declarations.
Performs type conversion from an arbitrary input type to a type that is expressed by a QuantizedType.
static ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.
const Type inputType
The input type that is being converted from.
Type convert(QuantizedType elementalType) const
Converts the inputType to be based on the given elemental type, returning the new type (or nullptr an...
const Type expressedType
Supported, elemental expressed type (i.e.