MLIR 23.0.0git
NormalizeQuantTypes.cpp
Go to the documentation of this file.
//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// Normalize generic quantized types to specific quantized types
11//
12//===----------------------------------------------------------------------===//
13
20#include "llvm/ADT/SmallVectorExtras.h"
21
22namespace mlir {
23namespace quant {
24
25#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
26#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
27
28namespace {
29
30/// Returns true if the given sub-channel quantized type is convertible to a
31/// per-tensor quantized type. This is true if the sub-channel type has only
32/// one scale and one zero point.
33///
34/// Assumes that `tensorType` is a tensor with element type
35/// `quant::UniformQuantizedSubChannelType`.
36static bool isConvertibleToPerTensor(TensorType tensorType) {
37 return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
38 .getScales()
39 .getType()
40 .getNumElements() == 1;
41}
42
43/// Returns true if the given sub-channel quantized type is convertible to a
44/// per-axis quantized type. This is true if the shape of the scales tensor has
45/// all but one non-one value.
46///
47/// Assumes that `tensorType` is a tensor with element type
48/// `quant::UniformQuantizedSubChannelType`.
49static bool isConvertibleToPerAxis(TensorType tensorType) {
50 auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
51 .getScales()
52 .getType()
53 .getShape();
54 return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
55}
56
57/// This class defines a type converter that converts sub-channel quantized
58/// types to per-tensor or per-axis quantized types whenever possible.
59class NormalizedQuantTypesConverter : public TypeConverter {
60
61 static Type convertType(Type type) {
62 auto tensorType = dyn_cast<TensorType>(type);
63 if (!tensorType) {
64 return type;
65 }
66
67 auto subChannelType =
68 dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
69 if (!subChannelType) {
70 return type;
71 }
72
73 if (isConvertibleToPerTensor(tensorType)) {
74 double scale =
75 subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
76 int64_t zeroPoint =
77 subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
78 auto perTensorType = UniformQuantizedType::get(
79 subChannelType.getFlags(), subChannelType.getStorageType(),
80 subChannelType.getExpressedType(), scale, zeroPoint,
81 subChannelType.getStorageTypeMin(),
82 subChannelType.getStorageTypeMax());
83 return tensorType.clone(perTensorType);
84 }
85
86 if (isConvertibleToPerAxis(tensorType)) {
87 auto shape = subChannelType.getScales().getType().getShape();
88 const auto *quantizedDimItr =
89 llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
90 auto scales = llvm::map_to_vector(
91 subChannelType.getScales().getValues<APFloat>(),
92 [](const APFloat &scale) { return scale.convertToDouble(); });
93 auto zeroPoints = llvm::map_to_vector(
94 subChannelType.getZeroPoints().getValues<APInt>(),
95 [](const APInt &zeroPoint) { return zeroPoint.getSExtValue(); });
96 auto perAxisType = UniformQuantizedPerAxisType::get(
97 subChannelType.getFlags(), subChannelType.getStorageType(),
98 subChannelType.getExpressedType(), scales, zeroPoints,
99 quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
100 subChannelType.getStorageTypeMax());
101 return tensorType.clone(perAxisType);
102 }
103 return type;
104 }
105
106public:
107 explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
108};
109
110/// This class implements a conversion pattern that converts any generic
111/// operation with sub-channel quantized types to an equivalent operation with
112/// per-tensor or per-axis quantized types.
113class ConvertGenericOpwithSubChannelType : public ConversionPattern {
114public:
115 ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
116 MLIRContext *context)
117 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
118
119 LogicalResult
120 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
121 ConversionPatternRewriter &rewriter) const final {
122 SmallVector<Type> resultTypes;
123 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
124 return failure();
125
126 auto *newOp = Operation::create(
127 op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
128 op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
129 for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
130 Region &before = std::get<0>(regions);
131 Region &parent = std::get<1>(regions);
132 rewriter.inlineRegionBefore(before, parent, parent.end());
133 if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
134 return failure();
135 }
136 rewriter.insert(newOp);
137 rewriter.replaceOp(op, newOp->getResults());
138 return success();
139 }
140};
141
142// Conversion pass
143class NormalizeQuantTypes
144 : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
145public:
146 void runOnOperation() override {
147
148 auto *context = &getContext();
149
150 NormalizedQuantTypesConverter typeConverter;
151 ConversionTarget target(*context);
152
153 // Determine legal operations.
154 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
155 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
156 typeConverter.isLegal(&op.getBody());
157 });
158 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
159 return typeConverter.isLegal(op->getOperandTypes()) &&
160 typeConverter.isLegal(op->getResultTypes());
161 });
162
163 // Register conversion patterns
164 RewritePatternSet patterns(context);
165 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
166 patterns, typeConverter);
167 patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
168
169 // Apply conversion
170 if (failed(
171 applyFullConversion(getOperation(), target, std::move(patterns))))
172 signalPassFailure();
173 }
174};
175
176} // namespace
177
178} // namespace quant
179} // namespace mlir
return success()
b getContext())
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns