MLIR  21.0.0git
QuantUtils.cpp
Go to the documentation of this file.
1 //===- QuantUtils.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains TOSA numerical support functions and quantization
10 // attribute builders.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
16 using namespace mlir;
17 using namespace mlir::tosa;
18 
/// Decomposes \p scale into a fixed-point multiplier and a TOSA right shift
/// for 16-bit scaling, so that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits the scale into mantissa * 2^exponent with the mantissa
/// magnitude in [0.5, 1.0); the mantissa is then rounded to 15 fractional
/// bits.
///
/// \param scale       Floating-point scale to decompose.
/// \param multiplier  Receives the rounded, 15-bit scaled mantissa.
/// \param shift       Receives the right shift, clamped to at most 62.
static void computeMultiplierAndShiftTosaScale16(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  // std::frexp writes through an `int *`; use a real int rather than aliasing
  // the int32_t out-parameter.
  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  shift = exponent;

  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 15) &&
         "Shifted mantissa exceeds 16 signed bits");

  // Rounding may push the mantissa up to exactly 1.0; renormalize it to 0.5
  // and compensate in the exponent.
  if (shiftedM == (int64_t(1) << 15)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive and embed (1 << 15) into right
  // shift bits.
  shift = (-shift) + 15;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
  // The limit of 62 on shift allows the shift to be decomposed as
  // two right shifts of 31.
  if (shift > 62) {
    // Shifting the multiplier by more than 31-bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
    shift = 62;
  }
}
56 
/// Decomposes \p scale into a fixed-point multiplier and a TOSA right shift
/// for 32-bit scaling, so that scale ~= multiplier * 2^(-shift).
///
/// std::frexp splits the scale into mantissa * 2^exponent with the mantissa
/// magnitude in [0.5, 1.0); the mantissa is then rounded to 31 fractional
/// bits.
///
/// \param scale       Floating-point scale to decompose.
/// \param multiplier  Receives the rounded, 31-bit scaled mantissa.
/// \param shift       Receives the right shift, clamped to at most 62.
static void computeMultiplierAndShiftTosaScale32(double scale,
                                                 int32_t &multiplier,
                                                 int32_t &shift) {
  // std::frexp writes through an `int *`; use a real int rather than aliasing
  // the int32_t out-parameter.
  int exponent = 0;
  const double mantissa = std::frexp(scale, &exponent);
  shift = exponent;

  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));

  // Can't be greater than 1.0.
  assert(shiftedM <= (int64_t(1) << 31) &&
         "Shifted mantissa exceeds 32 signed bits");

  // Rounding may push the mantissa up to exactly 1.0, which would not fit in
  // a signed 32-bit value; renormalize it to 0.5 and compensate in the
  // exponent.
  if (shiftedM == (int64_t(1) << 31)) {
    shiftedM /= 2;
    shift++;
  }

  // TOSA expects right shift to be positive, and embed (1 << 31) into right
  // shift bits.
  shift = (-shift) + 31;

  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
         "Shifted mantissa exceeds 32-bit signed output type");

  multiplier = static_cast<int32_t>(shiftedM);

  // Shifting tops out at 62 bits. Right shift to make 62 bits the max.
  // The limit of 62 on shift allows the shift to be decomposed as
  // two right shifts of 31.
  if (shift > 62) {
    // Shifting the multiplier by more than 31-bits is unnecessary.
    multiplier = multiplier >> std::min<int32_t>(31, shift - 62);
    shift = 62;
  }
}
93 
94 /// Generates a quantized multiplier/shift from double.
95 bool mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
96  int32_t &shift, int32_t scaleWidth) {
97 
98  switch (scaleWidth) {
99  case 16:
100  computeMultiplierAndShiftTosaScale16(scale, multiplier, shift);
101 
102  // In some cases computeMultiplierAndShiftTosaScale16 can return
103  // a value less then 2, which is not valid in the TOSA spec.
104  return (!(shift < 2));
105  case 32:
106  computeMultiplierAndShiftTosaScale32(scale, multiplier, shift);
107 
108  // In some cases computeMultiplierAndShiftTosaScale32 can return
109  // a value less then 2, which is not valid in the TOSA spec.
110  return (!(shift < 2));
111  default:
112  assert(0 && "Unsupported Tosa quantized_scale regime specified!");
113  return false;
114  }
115 }
116 
/// Casts the element type of \p inputType to quant::UniformQuantizedType via
/// llvm::dyn_cast. \p inputType must provide getElementType().
#define GET_UQTYPE(inputType)                                                  \
  (llvm::dyn_cast<quant::UniformQuantizedType>((inputType).getElementType()))
/// Casts the element type of \p inputType to quant::QuantizedType via
/// llvm::dyn_cast. \p inputType must provide getElementType().
#define GET_QTYPE(inputType)                                                   \
  (llvm::dyn_cast<quant::QuantizedType>((inputType).getElementType()))
121 
122 static std::optional<std::pair<std::int64_t, std::int64_t>>
123 getConvZeroPoints(Value input, Value weight) {
124 
125  auto inputType = dyn_cast<ShapedType>(input.getType());
126  auto weightType = dyn_cast<ShapedType>(weight.getType());
127 
128  if (!inputType || !weightType)
129  return std::nullopt;
130 
131  auto inputQType = GET_UQTYPE(inputType);
132  auto weightPerTensorQType = GET_UQTYPE(weightType);
133  auto weightPerAxisQType =
134  dyn_cast<quant::UniformQuantizedPerAxisType>(weightType.getElementType());
135 
136  // Weights must be either per-tensor quantized or per-axis quantized.
137  assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) &&
138  "Weights must be either per-tensor or per-axis quantized");
139 
140  // Either all quantized or all not quantized.
141  assert(!((bool)inputQType ^
142  ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) &&
143  "Inputs and weights must be all quantized or all not quantized");
144 
145  if (inputQType) {
146  int64_t inputZp = inputQType.getZeroPoint();
147  int64_t weightZp = 0;
148 
149  if (weightPerTensorQType) {
150  weightZp = weightPerTensorQType.getZeroPoint();
151  } else if (weightPerAxisQType) {
152  weightZp = weightPerAxisQType.getZeroPoints().front();
153  }
154 
155  return std::make_pair(inputZp, weightZp);
156  }
157 
158  return std::nullopt;
159 }
160 
161 std::pair<Value, Value>
163  std::int64_t inputZp, weightZp;
164 
165  auto inputEType = getElementTypeOrSelf(input.getType());
166  auto weightEType = getElementTypeOrSelf(weight.getType());
167 
168  if (mlir::isa<FloatType>(inputEType) && mlir::isa<FloatType>(weightEType)) {
169  inputZp = 0;
170  weightZp = 0;
171  } else {
172  auto maybeZps = getConvZeroPoints(input, weight);
173  if (!maybeZps.has_value())
174  return {};
175 
176  inputZp = maybeZps->first;
177  weightZp = maybeZps->second;
178  }
179 
180  auto maybeInputZpValue =
181  createZeroPointTensor(builder, input.getLoc(), inputEType, inputZp);
182  if (!maybeInputZpValue.has_value())
183  return {};
184 
185  auto maybeWeightZpValue =
186  createZeroPointTensor(builder, weight.getLoc(), weightEType, weightZp);
187  if (!maybeWeightZpValue.has_value())
188  return {};
189 
190  return std::make_pair(*maybeInputZpValue, *maybeWeightZpValue);
191 }
192 
193 /// Method to build ConvOpQuantizationAttr, called from
194 /// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
195 /// input_zp: input zeropoint
196 /// weight_zp: weight zeropoint.
197 ConvOpQuantizationAttr
199  Value weight) {
200 
201  auto maybeZps = getConvZeroPoints(input, weight);
202  if (!maybeZps.has_value())
203  return nullptr;
204 
205  return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first,
206  maybeZps->second);
207 }
208 
209 /// Builds MatMulOpQuantizationAttr, called from
210 /// MatMulOpQuantInfoBuilder:
211 /// aZp: input a zeropoint
212 /// bZp: input b zeropoint.
213 MatMulOpQuantizationAttr
215  Value b) {
216 
217  auto aType = dyn_cast<ShapedType>(a.getType());
218  auto bType = dyn_cast<ShapedType>(b.getType());
219 
220  if (!aType || !bType)
221  return nullptr;
222 
223  auto aQType = GET_UQTYPE(aType);
224  auto bQType = GET_UQTYPE(bType);
225 
226  // A and B are either all quantized or all not quantized.
227  assert(!((bool)aQType ^ (bool)bQType) &&
228  "Matmul operands must be all quantized or all not quantized");
229 
230  if (aQType) {
231  return builder.getAttr<tosa::MatMulOpQuantizationAttr>(
232  aQType.getZeroPoint(), bQType.getZeroPoint());
233  }
234 
235  return nullptr;
236 }
237 
238 /// Builds UnaryOpQuantizationAttr
239 /// UnaryOpQuantInfoBuilder:
240 /// inputZp: input zeropoint
241 /// outputZp: output zeropoint.
242 UnaryOpQuantizationAttr
244  Type outputRawType) {
245 
246  auto inputType = dyn_cast<ShapedType>(input.getType());
247  auto outputType = dyn_cast<ShapedType>(outputRawType);
248 
249  if (!inputType || !outputType)
250  return nullptr;
251 
252  auto inputQType = GET_UQTYPE(inputType);
253  auto outputQType = GET_UQTYPE(outputType);
254 
255  // Either all quantized or all not quantized.
256  assert(!((bool)inputQType ^ (bool)outputQType) &&
257  "Unary inputs/outputs must be all quantized or all not quantized");
258 
259  if (inputQType) {
260  return builder.getAttr<UnaryOpQuantizationAttr>(inputQType.getZeroPoint(),
261  outputQType.getZeroPoint());
262  }
263 
264  return nullptr;
265 }
266 
267 /// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder:
268 /// inputZp: input zeropoint.
269 PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder,
270  Value input) {
271 
272  auto inputType = dyn_cast<ShapedType>(input.getType());
273 
274  if (!inputType)
275  return nullptr;
276 
277  auto inputQType = GET_UQTYPE(inputType);
278 
279  if (inputQType) {
280  return builder.getAttr<tosa::PadOpQuantizationAttr>(
281  inputQType.getZeroPoint());
282  }
283 
284  return nullptr;
285 }
286 
287 /// Builds output type for a quantized ConvOp with the right bitwidth.
288 /// This is called by the builder when dealing with quantized content.
290  Value input, Value weight) {
291 
292  auto inputType = dyn_cast<ShapedType>(input.getType());
293  auto weightType = dyn_cast<ShapedType>(weight.getType());
294 
295  assert(inputType && weightType &&
296  "Could not extract input or weight tensors from Conv op");
297 
298  auto inputQType = GET_QTYPE(inputType);
299  auto weightQType = GET_QTYPE(weightType);
300 
301  assert(inputQType && weightQType &&
302  "Could not extract input or weight tensor types from Conv op");
303 
304  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
305  unsigned weightBits = weightQType.getStorageTypeIntegralWidth();
306 
307  auto outputShapedType = dyn_cast<ShapedType>(outputType);
308  assert(outputShapedType &&
309  "Could not extract output shape type from Conv op");
310 
311  IntegerType accElementType;
312  if (inputBits == 16 && weightBits == 8)
313  accElementType = builder.getIntegerType(48);
314  else
315  accElementType = builder.getI32Type();
316  auto accType = outputShapedType.clone(accElementType);
317  return accType;
318 }
319 
320 /// Builds Tosa quantization attributes from min/max values.
322  Attribute minAttr, Attribute maxAttr,
323  IntegerAttr quantBits, int filterQuantDim,
324  bool isSigned, BoolAttr narrowRange) {
325 
326  quant::QuantizedType retType;
327 
328  auto convfunc =
330 
331  auto minElems = dyn_cast<DenseFPElementsAttr>(minAttr);
332  auto maxElems = dyn_cast<DenseFPElementsAttr>(maxAttr);
333 
335 
336  // At least one is per-axis quantized elementsattr.
337  if (minElems || maxElems) {
338  // Must have the same number of elements.
339  if (minElems.getNumElements() != maxElems.getNumElements())
340  return {};
341  min.reserve(minElems.getNumElements());
342  max.reserve(maxElems.getNumElements());
343  for (auto i : minElems)
344  min.push_back(FloatAttr::getValueAsDouble(i));
345  for (auto i : maxElems)
346  max.push_back(FloatAttr::getValueAsDouble(i));
347  } else { // Just a single FP value.
348  auto minVal = dyn_cast<FloatAttr>(minAttr);
349  if (minVal)
350  min.push_back(minVal.getValueAsDouble());
351  else
352  return {};
353  auto maxVal = dyn_cast<FloatAttr>(maxAttr);
354  if (maxVal)
355  max.push_back(maxVal.getValueAsDouble());
356  else
357  return {};
358  }
359 
360  if (min.size() == max.size()) {
361  if (min.size() == 1) { // Per-tensor quantization with one min/max pair.
362  retType = quant::fakeQuantAttrsToType(
363  builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0],
364  narrowRange.getValue(), convfunc.expressedType, isSigned);
365  } else if (min.size() > 1) { // Per-axis quant on filterQuantDim.
366  auto shape = dyn_cast<ShapedType>(inputDType);
367  if (!shape)
368  return {};
369  if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) {
370  retType = quant::fakeQuantAttrsToType(
371  builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0],
372  max[0], narrowRange.getValue(), convfunc.expressedType, isSigned);
373  }
374  } else {
375  return {};
376  }
377  } else {
378  return {};
379  }
380 
381  if (!retType)
382  return {};
383 
384  return convfunc.convert(retType);
385 }
386 
387 /// Builds Tosa quantization attributes from min/max values.
388 TypeAttr
390  Attribute minAttr, Attribute maxAttr,
391  IntegerAttr quantBits, int filterQuantDim,
392  bool isSigned, BoolAttr narrowRange) {
393 
394  return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr,
395  maxAttr, quantBits, filterQuantDim,
396  isSigned, narrowRange));
397 }
static std::pair< Value, Value > frexp(ImplicitLocOpBuilder &builder, Value arg, bool isPositive=false)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define GET_UQTYPE(inputType)
Definition: QuantUtils.cpp:117
static std::optional< std::pair< std::int64_t, std::int64_t > > getConvZeroPoints(Value input, Value weight)
Definition: QuantUtils.cpp:123
static void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, int32_t &shift)
From a scale value, generates multiplier and shift values where mantissa is in [-1....
Definition: QuantUtils.cpp:22
#define GET_QTYPE(inputType)
Definition: QuantUtils.cpp:119
static void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, int32_t &shift)
From a scale value, generates multiplier and shift values where mantissa is in [-1....
Definition: QuantUtils.cpp:60
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
Location getUnknownLoc()
Definition: Builders.cpp:27
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
This class helps build Operations.
Definition: Builders.h:205
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:49
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, double rmax, bool narrowRange, Type expressedType, bool isSigned=false)
Converts per-layer FakeQuant attributes to the corresponding type.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, Attribute maxAttr, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange)
Builds Tosa quantization attributes from min/max values.
Definition: QuantUtils.cpp:389
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
bool computeMultiplierAndShift(double scale, int32_t &multiplier, int32_t &shift, int32_t scaleWidth)
From a scale value, computes multiplier and shift values for 16 or 32-bit scale widths.
Definition: QuantUtils.cpp:95
Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, Attribute maxAttr, IntegerAttr quantBits, int filterQuantDim, bool isSigned, BoolAttr narrowRange)
Builds Tosa quantization attributes from min/max values.
Definition: QuantUtils.cpp:321
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:3299
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Include the generated interface declarations.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.