MLIR  14.0.0git
UniformSupport.cpp
Go to the documentation of this file.
1 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/BuiltinTypes.h"
11 #include <numeric>
12 
13 using namespace mlir;
14 using namespace mlir::quant;
15 
16 static bool isQuantizablePrimitiveType(Type inputType) {
17  return inputType.isa<FloatType>();
18 }
19 
// Body of ExpressedToQuantizedConverter::forInputType(Type inputType): builds a
// converter whose second field is the elemental type to be quantized.
// NOTE(review): this listing skips listing lines 20-21 (the signature), 25 and
// 32, so the statement guarded by the check on listing line 24 and the one
// after the "Unsupported" comment are not visible here. Presumably both return
// a converter marked as unsupported; confirm against the real file.
//
// Tensor and vector inputs are handled through their element type, which must
// satisfy isQuantizablePrimitiveType().
22  if (inputType.isa<TensorType, VectorType>()) {
23  Type elementType = inputType.cast<ShapedType>().getElementType();
24  if (!isQuantizablePrimitiveType(elementType))
// (listing line 25, the statement taken for a non-quantizable element type, is
// missing from this listing)
26  return ExpressedToQuantizedConverter{inputType, elementType};
27  }
28  // Supported primitive type (which just is the expressed type).
29  if (isQuantizablePrimitiveType(inputType))
30  return ExpressedToQuantizedConverter{inputType, inputType};
31  // Unsupported.
// (listing line 32, the fall-through return, is missing from this listing)
33 }
34 
// Body of ExpressedToQuantizedConverter::convert(QuantizedType elementalType);
// the signature (listing line 35) is missing from this listing.
// Rebuilds inputType around elementalType: ranked tensors and vectors keep
// their shape, unranked tensors stay unranked. For any other inputType the
// elemental type itself is returned when its expressed type equals
// expressedType; otherwise the result is nullptr.
// Asserts that expressedType is non-null before converting.
36  assert(expressedType && "convert() on unsupported conversion");
37  if (auto tensorType = inputType.dyn_cast<RankedTensorType>())
38  return RankedTensorType::get(tensorType.getShape(), elementalType);
39  if (auto tensorType = inputType.dyn_cast<UnrankedTensorType>())
40  return UnrankedTensorType::get(elementalType);
41  if (auto vectorType = inputType.dyn_cast<VectorType>())
42  return VectorType::get(vectorType.getShape(), elementalType);
43 
44  // If the expressed types match, just use the new elemental type.
45  if (elementalType.getExpressedType() == expressedType)
46  return elementalType;
47  // Unsupported.
48  return nullptr;
49 }
50 
// Quantizes `realValue` when it is a DenseFPElementsAttr by calling
// convert(attr) on the casted attribute; every other attribute kind yields
// nullptr.
// NOTE(review): the line carrying the qualified function name and parameter
// list (listing line 52) is missing from this listing.
51 ElementsAttr
53  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
54  return convert(attr);
55  }
56  // TODO: handle sparse elements attributes.
57  return nullptr;
58 }
59 
// Body of the dense floating-point quantization routine that applies a
// separate converter to each chunk along quantizationDim.
// NOTE(review): the signature (listing lines 60-61) and the declaration of
// `converters` (listing line 69) are missing from this listing.
62  // Creates the converter for each chunk. Normally the size of the
63  // quantization dim is 3, so we can cache all the converters.
64  ShapedType type = attr.getType();
65  size_t dimSize = type.getDimSize(quantizationDim);
// Bail out with a default-constructed result when the number of scales does
// not match the size of the quantization dimension.
66  if (dimSize != scales.size()) {
67  return {};
68  }
70  converters.reserve(dimSize);
71  for (int i = 0, e = dimSize; i != e; ++i) {
72  converters.push_back(getPerChunkConverter(i));
73  }
74 
75  // Scan the elements of the dense elements attributes and quantize them by
76  // using the right quantization parameters.
77  int64_t flattenIndex = 0;
78  auto shape = type.getShape();
// chunkSize is the product of the dimension sizes that follow quantizationDim.
// NOTE(review): the initial value `1` is an int, so std::accumulate carries the
// running product as int even though std::multiplies<int64_t> is used; seeding
// with int64_t{1} would avoid truncation for very large shapes.
79  int64_t chunkSize =
80  std::accumulate(std::next(shape.begin(), quantizationDim + 1),
81  shape.end(), 1, std::multiplies<int64_t>());
82  Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
// flattenIndex counts the elements mapped so far; dividing by chunkSize and
// wrapping by dimSize picks the converter for the current element.
// (assumes mapValues visits elements in row-major order; TODO confirm)
// NOTE(review): chunkIndex is an int while flattenIndex and chunkSize are
// int64_t, so the quotient is narrowed.
83  return attr.mapValues(newElementType, [&](const APFloat &old) {
84  int chunkIndex = (flattenIndex++) / chunkSize;
85  return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
86  });
87 }
Include the generated interface declarations.
An attribute that represents a reference to a dense float vector or tensor object.
Type getExpressedType() const
Gets the original expressed type that this quantized type approximates.
Definition: QuantTypes.cpp:81
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
const Type expressedType
Supported, elemental expressed type (i.e. a floating-point type). Null when the conversion is unsupported.
An attribute that represents a reference to a dense vector or tensor object.
U dyn_cast() const
Definition: Types.h:244
Performs type conversion from an arbitrary input type to a type that is expressed by a QuantizedType...
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static bool isQuantizablePrimitiveType(Type inputType)
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:73
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
ElementsAttr convert(Attribute realValue)
Quantize an Attribute by the quantization parameters.
U dyn_cast() const
Definition: Attributes.h:117
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
Type convert(QuantizedType elementalType) const
Converts the inputType to be based on the given elemental type, returning the new type (or nullptr if the conversion is unsupported).
const Type inputType
The input type that is being converted from.
static ExpressedToQuantizedConverter forInputType(Type inputType)
Creates a converter for the given input type.
bool isa() const
Definition: Types.h:234
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APFloat &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
U cast() const
Definition: Types.h:250