MLIR 22.0.0git
NormalizeQuantTypes.cpp
Go to the documentation of this file.
//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Normalize generic quantized types to specific quantized types
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"

21namespace mlir {
22namespace quant {
23
24#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
25#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
26
27namespace {
28
29/// Returns true if the given sub-channel quantized type is convertible to a
30/// per-tensor quantized type. This is true if the sub-channel type has only
31/// one scale and one zero point.
32///
33/// Assumes that `tensorType` is a tensor with element type
34/// `quant::UniformQuantizedSubChannelType`.
35static bool isConvertibleToPerTensor(TensorType tensorType) {
36 return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
37 .getScales()
38 .getType()
39 .getNumElements() == 1;
40}
41
42/// Returns true if the given sub-channel quantized type is convertible to a
43/// per-axis quantized type. This is true if the shape of the scales tensor has
44/// all but one non-one value.
45///
46/// Assumes that `tensorType` is a tensor with element type
47/// `quant::UniformQuantizedSubChannelType`.
48static bool isConvertibleToPerAxis(TensorType tensorType) {
49 auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
50 .getScales()
51 .getType()
52 .getShape();
53 return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
54}
55
56/// This class defines a type converter that converts sub-channel quantized
57/// types to per-tensor or per-axis quantized types whenever possible.
58class NormalizedQuantTypesConverter : public TypeConverter {
59
60 static Type convertType(Type type) {
61 auto tensorType = dyn_cast<TensorType>(type);
62 if (!tensorType) {
63 return type;
64 }
65
66 auto subChannelType =
67 dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
68 if (!subChannelType) {
69 return type;
70 }
71
72 if (isConvertibleToPerTensor(tensorType)) {
73 double scale =
74 subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
75 int64_t zeroPoint =
76 subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
77 auto perTensorType = UniformQuantizedType::get(
78 subChannelType.getFlags(), subChannelType.getStorageType(),
79 subChannelType.getExpressedType(), scale, zeroPoint,
80 subChannelType.getStorageTypeMin(),
81 subChannelType.getStorageTypeMax());
82 return tensorType.clone(perTensorType);
83 }
84
85 if (isConvertibleToPerAxis(tensorType)) {
86 auto shape = subChannelType.getScales().getType().getShape();
87 const auto *quantizedDimItr =
88 llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
89 auto scales = llvm::to_vector(llvm::map_range(
90 subChannelType.getScales().getValues<APFloat>(),
91 [](const APFloat &scale) { return scale.convertToDouble(); }));
92 auto zeroPoints = llvm::to_vector(llvm::map_range(
93 subChannelType.getZeroPoints().getValues<APInt>(),
94 [](const APInt &zeroPoint) { return zeroPoint.getSExtValue(); }));
95 auto perAxisType = UniformQuantizedPerAxisType::get(
96 subChannelType.getFlags(), subChannelType.getStorageType(),
97 subChannelType.getExpressedType(), scales, zeroPoints,
98 quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
99 subChannelType.getStorageTypeMax());
100 return tensorType.clone(perAxisType);
101 }
102 return type;
103 }
104
105public:
106 explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
107};
109/// This class implements a conversion pattern that converts any generic
110/// operation with sub-channel quantized types to an equivalent operation with
111/// per-tensor or per-axis quantized types.
112class ConvertGenericOpwithSubChannelType : public ConversionPattern {
113public:
114 ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
115 MLIRContext *context)
116 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
117
118 LogicalResult
119 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
120 ConversionPatternRewriter &rewriter) const final {
121 SmallVector<Type> resultTypes;
122 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
123 return failure();
124
125 auto *newOp = Operation::create(
126 op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
127 op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
128 for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
129 Region &before = std::get<0>(regions);
130 Region &parent = std::get<1>(regions);
131 rewriter.inlineRegionBefore(before, parent, parent.end());
132 if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
133 return failure();
134 }
135 rewriter.insert(newOp);
136 rewriter.replaceOp(op, newOp->getResults());
137 return success();
138 }
139};
140
141// Conversion pass
142class NormalizeQuantTypes
143 : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
144public:
145 void runOnOperation() override {
146
147 auto *context = &getContext();
148
149 NormalizedQuantTypesConverter typeConverter;
150 ConversionTarget target(*context);
151
152 // Determine legal operations.
153 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
154 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
155 typeConverter.isLegal(&op.getBody());
156 });
157 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
158 return typeConverter.isLegal(op->getOperandTypes()) &&
159 typeConverter.isLegal(op->getResultTypes());
160 });
162 // Register conversion patterns
164 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
165 patterns, typeConverter);
166 patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
167
168 // Apply conversion
169 if (failed(
170 applyFullConversion(getOperation(), target, std::move(patterns))))
171 signalPassFailure();
172 }
173};
174
175} // namespace
176
177} // namespace quant
178} // namespace mlir
return success()
b getContext())
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator end()
Definition Region.h:56
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns