MLIR  22.0.0git
NormalizeQuantTypes.cpp
Go to the documentation of this file.
//===- NormalizeQuantTypes.cpp - Normalize quantized types ----------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Normalize generic quantized types to specific quantized types
11 //
12 //===----------------------------------------------------------------------===//
13 
20 
21 namespace mlir {
22 namespace quant {
23 
24 #define GEN_PASS_DEF_NORMALIZEQUANTTYPES
25 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
26 
27 namespace {
28 
29 /// Returns true if the given sub-channel quantized type is convertible to a
30 /// per-tensor quantized type. This is true if the sub-channel type has only
31 /// one scale and one zero point.
32 ///
33 /// Assumes that `tensorType` is a tensor with element type
34 /// `quant::UniformQuantizedSubChannelType`.
35 static bool isConvertibleToPerTensor(TensorType tensorType) {
36  return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
37  .getScales()
38  .getType()
39  .getNumElements() == 1;
40 }
41 
42 /// Returns true if the given sub-channel quantized type is convertible to a
43 /// per-axis quantized type. This is true if the shape of the scales tensor has
44 /// all but one non-one value.
45 ///
46 /// Assumes that `tensorType` is a tensor with element type
47 /// `quant::UniformQuantizedSubChannelType`.
48 static bool isConvertibleToPerAxis(TensorType tensorType) {
49  auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
50  .getScales()
51  .getType()
52  .getShape();
53  return llvm::count_if(shape, [](int64_t dim) { return dim != 1; }) == 1;
54 }
55 
56 /// This class defines a type converter that converts sub-channel quantized
57 /// types to per-tensor or per-axis quantized types whenever possible.
58 class NormalizedQuantTypesConverter : public TypeConverter {
59 
60  static Type convertType(Type type) {
61  auto tensorType = dyn_cast<TensorType>(type);
62  if (!tensorType) {
63  return type;
64  }
65 
66  auto subChannelType =
67  dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
68  if (!subChannelType) {
69  return type;
70  }
71 
72  if (isConvertibleToPerTensor(tensorType)) {
73  double scale =
74  subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
75  int64_t zeroPoint =
76  subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
77  auto perTensorType = UniformQuantizedType::get(
78  subChannelType.getFlags(), subChannelType.getStorageType(),
79  subChannelType.getExpressedType(), scale, zeroPoint,
80  subChannelType.getStorageTypeMin(),
81  subChannelType.getStorageTypeMax());
82  return tensorType.clone(perTensorType);
83  }
84 
85  if (isConvertibleToPerAxis(tensorType)) {
86  auto shape = subChannelType.getScales().getType().getShape();
87  auto quantizedDimItr =
88  llvm::find_if(shape, [](int64_t dim) { return dim != 1; });
89  auto scales = llvm::to_vector(llvm::map_range(
90  subChannelType.getScales().getValues<APFloat>(),
91  [](APFloat scale) { return scale.convertToDouble(); }));
92  auto zeroPoints = llvm::to_vector(llvm::map_range(
93  subChannelType.getZeroPoints().getValues<APInt>(),
94  [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
95  auto perAxisType = UniformQuantizedPerAxisType::get(
96  subChannelType.getFlags(), subChannelType.getStorageType(),
97  subChannelType.getExpressedType(), scales, zeroPoints,
98  quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
99  subChannelType.getStorageTypeMax());
100  return tensorType.clone(perAxisType);
101  }
102  return type;
103  }
104 
105 public:
106  explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
107 };
108 
109 /// This class implements a conversion pattern that converts any generic
110 /// operation with sub-channel quantized types to an equivalent operation with
111 /// per-tensor or per-axis quantized types.
112 class ConvertGenericOpwithSubChannelType : public ConversionPattern {
113 public:
114  ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
115  MLIRContext *context)
116  : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
117 
118  LogicalResult
119  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
120  ConversionPatternRewriter &rewriter) const final {
121  SmallVector<Type> resultTypes;
122  if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
123  return failure();
124 
125  auto *newOp = Operation::create(
126  op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
127  op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
128  for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
129  Region &before = std::get<0>(regions);
130  Region &parent = std::get<1>(regions);
131  rewriter.inlineRegionBefore(before, parent, parent.end());
132  if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
133  return failure();
134  }
135  rewriter.insert(newOp);
136  rewriter.replaceOp(op, newOp->getResults());
137  return success();
138  }
139 };
140 
141 // Conversion pass
142 class NormalizeQuantTypes
143  : public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
144 public:
145  void runOnOperation() override {
146 
147  auto *context = &getContext();
148 
149  NormalizedQuantTypesConverter typeConverter;
150  ConversionTarget target(*context);
151 
152  // Determine legal operations.
153  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
154  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
155  typeConverter.isLegal(&op.getBody());
156  });
157  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
158  return typeConverter.isLegal(op->getOperandTypes()) &&
159  typeConverter.isLegal(op->getResultTypes());
160  });
161 
162  // Register conversion patterns
163  RewritePatternSet patterns(context);
164  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
165  patterns, typeConverter);
166  patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
167 
168  // Apply conversion
169  if (failed(
170  applyFullConversion(getOperation(), target, std::move(patterns))))
171  signalPassFailure();
172  }
173 };
174 
175 } // namespace
176 
177 } // namespace quant
178 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
static UniformQuantizedPerAxisType get(unsigned flags, Type storageType, Type expressedType, ArrayRef< double > scales, ArrayRef< int64_t > zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:335
static UniformQuantizedType get(unsigned flags, Type storageType, Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
Gets an instance of the type with all parameters specified but not checked.
Definition: QuantTypes.cpp:280
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns