MLIR  21.0.0git
NormalizeQuantTypes.cpp
Go to the documentation of this file.
//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Normalize generic quantized types to specific quantized types
11 //
12 //===----------------------------------------------------------------------===//
13 
21 
22 namespace mlir {
23 namespace quant {
24 
25 #define GEN_PASS_DEF_NORMALIZEQUANTTYPES
26 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
27 
28 namespace {
29 
30 /// Returns true if the given sub-channel quantized type is convertible to a
31 /// per-tensor quantized type. This is true if the sub-channel type has only
32 /// one scale and one zero point.
33 ///
34 /// Assumes that `tensorType` is a tensor with element type
35 /// `quant::UniformQuantizedSubChannelType`.
36 static bool isConvertibleToPerTensor(TensorType tensorType) {
37  return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
38  .getScales()
39  .getType()
40  .getNumElements() == 1;
41 }
42 
43 /// Returns true if the given sub-channel quantized type is convertible to a
44 /// per-axis quantized type. This is true if the shape of the scales tensor has
45 /// all but one non-one value.
46 ///
47 /// Assumes that `tensorType` is a tensor with element type
48 /// `quant::UniformQuantizedSubChannelType`.
49 static bool isConvertibleToPerAxis(TensorType tensorType) {
50  auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
51  .getScales()
52  .getType()
53  .getShape();
54  return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
55 }
56 
57 /// This class defines a type converter that converts sub-channel quantized
58 /// types to per-tensor or per-axis quantized types whenever possible.
59 class NormalizedQuantTypesConverter : public TypeConverter {
60 
61  static Type convertType(Type type) {
62  auto tensorType = dyn_cast<TensorType>(type);
63  if (!tensorType) {
64  return type;
65  }
66 
67  auto subChannelType =
68  dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
69  if (!subChannelType) {
70  return type;
71  }
72 
73  if (isConvertibleToPerTensor(tensorType)) {
74  double scale =
75  subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
76  int64_t zeroPoint =
77  subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
78  auto perTensorType = UniformQuantizedType::get(
79  subChannelType.getFlags(), subChannelType.getStorageType(),
80  subChannelType.getExpressedType(), scale, zeroPoint,
81  subChannelType.getStorageTypeMin(),
82  subChannelType.getStorageTypeMax());
83  return tensorType.clone(perTensorType);
84  }
85 
86  if (isConvertibleToPerAxis(tensorType)) {
87  auto shape = subChannelType.getScales().getType().getShape();
88  auto quantizedDimItr =
89  llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
90  auto scales = llvm::to_vector(llvm::map_range(
91  subChannelType.getScales().getValues<APFloat>(),
92  [](APFloat scale) { return scale.convertToDouble(); }));
93  auto zeroPoints = llvm::to_vector(llvm::map_range(
94  subChannelType.getZeroPoints().getValues<APInt>(),
95  [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
96  auto perAxisType = UniformQuantizedPerAxisType::get(
97  subChannelType.getFlags(), subChannelType.getStorageType(),
98  subChannelType.getExpressedType(), scales, zeroPoints,
99  quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
100  subChannelType.getStorageTypeMax());
101  return tensorType.clone(perAxisType);
102  }
103  return type;
104  }
105 
106 public:
107  explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
108 };
109 
110 /// This class implements a conversion pattern that converts any generic
111 /// operation with sub-channel quantized types to an equivalent operation with
112 /// per-tensor or per-axis quantized types.
113 class ConvertGenericOpwithSubChannelType : public ConversionPattern {
114 public:
115  ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
116  MLIRContext *context)
117  : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
118 
119  LogicalResult
120  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
121  ConversionPatternRewriter &rewriter) const final {
122  SmallVector<Type> resultTypes;
123  if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
124  return failure();
125 
126  auto *newOp = Operation::create(
127  op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
128  op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
129  for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
130  Region &before = std::get<0>(regions);
131  Region &parent = std::get<1>(regions);
132  rewriter.inlineRegionBefore(before, parent, parent.end());
133  if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
134  return failure();
135  }
136  rewriter.insert(newOp);
137  rewriter.replaceOp(op, newOp->getResults());
138  return success();
139  }
140 };
141 
142 // Conversion pass
143 class NormalizeQuantTypes
144  : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
145 public:
146  void runOnOperation() override {
147 
148  auto *context = &getContext();
149 
150  NormalizedQuantTypesConverter typeConverter;
151  ConversionTarget target(*context);
152 
153  // Determine legal operations.
154  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
155  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
156  typeConverter.isLegal(&op.getBody());
157  });
158  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
159  return typeConverter.isLegal(op->getOperandTypes()) &&
160  typeConverter.isLegal(op->getResultTypes());
161  });
162 
163  // Register conversion patterns
164  RewritePatternSet patterns(context);
165  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
166  patterns, typeConverter);
167  patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
168 
169  // Apply conversion
170  if (failed(
171  applyFullConversion(getOperation(), target, std::move(patterns))))
172  signalPassFailure();
173  }
174 };
175 
176 } // namespace
177 
178 } // namespace quant
179 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:338
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:283
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns