MLIR  20.0.0git
QuantOps.cpp
Go to the documentation of this file.
1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "QuantDialectBytecode.h"
10 #include "TypeDetail.h"
11 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 
18 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
19 
20 
21 namespace mlir {
22 namespace quant {
23 
24 namespace {
25 
26 // Verify the integrity of per-axis quantization information, if present.
27 //
28 // - quantizedType
29 // Any quantized type. Any quantized type with no per-axis quantization is
30 // ignored.
31 //
32 // - containerType
33 // Original input or result type of the operation using the provided quantized
34 // type. Used to ensure that the quantized type appears within a tensor and
35 // that the tensor is compatible with per-axis quantization information.
36 //
37 LogicalResult verifyPerAxisQuantization(Operation *op,
38  QuantizedType quantizedType,
39  Type containerType) {
40  auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
41  if (!quantizedPerAxisType)
42  return success();
43 
44  auto tensorType = dyn_cast<TensorType>(containerType);
45  if (!tensorType)
46  return op->emitError("scalar types may not use per-axis quantization");
47 
48  if (!tensorType.hasRank())
49  return success();
50 
51  int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
52  if (quantizedDimension >= tensorType.getRank())
53  return op->emitError("quantized dimension must be less than tensor rank");
54 
55  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
56  if (quantizedDimensionSize != ShapedType::kDynamic &&
57  quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
58  return op->emitError(
59  "quantized dimension size does not match number of scales");
60 
61  return success();
62 }
63 
64 // Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
65 //
66 // - quantizedType
67 // Quantized type used in the input ('quant.dcast') or result ('quant.qcast'),
68 // whether as a primitive type or in a tensor.
69 //
70 // - floatType
71 // Float type used in the input ('quant.qcast') or result ('quant.dcast'),
72 // whether as a primitive type or in a tensor.
73 //
74 // - containerType
75 // Type of original input or result.
76 //
77 LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
78  FloatType floatType, Type containerType) {
79  if (quantizedType.getExpressedType() != floatType)
80  return op->emitError(
81  "expressed type in quantized type expected to match float type");
82 
83  // Veriy integrity of per-axis quantization information, if present.
84  return verifyPerAxisQuantization(op, quantizedType, containerType);
85 }
86 
87 } // namespace
88 
89 
90 //===----------------------------------------------------------------------===//
91 // Dialect
92 //===----------------------------------------------------------------------===//
93 
94 void QuantDialect::initialize() {
95  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
96  UniformQuantizedPerAxisType>();
97  addOperations<
98 #define GET_OP_LIST
99 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
100  >();
102 }
103 
104 
105 //===----------------------------------------------------------------------===//
106 // DequantizeCastOp
107 //===----------------------------------------------------------------------===//
108 
109 LogicalResult DequantizeCastOp::verify() {
110  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
111  getInput().getType());
112 }
113 
114 OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
115  // Matches x -> quant.qcast -> quant.dcast -> y, replacing the quant.dcast op
116  // with the value of x. Values x and y are guaranteed to be of the same type
117  // in this pattern.
118  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
119  if (!srcQcastOp)
120  return {};
121  assert(srcQcastOp.getInput().getType() == getType());
122  return srcQcastOp.getInput();
123 }
124 
125 FloatType DequantizeCastOp::getFloatType() {
126  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
127 }
128 
129 QuantizedType DequantizeCastOp::getQuantizedType() {
130  return cast<QuantizedType>(getElementTypeOrSelf(getInput().getType()));
131 }
132 
133 
134 //===----------------------------------------------------------------------===//
135 // QuantizeCastOp
136 //===----------------------------------------------------------------------===//
137 
138 LogicalResult QuantizeCastOp::verify() {
139  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
140  getInput().getType());
141 }
142 
143 OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
144  // Matches x -> quant.dcast -> quant.qcast -> y, replacing the quant.qcast op
145  // with the value of x if the casts invert each other. Contrary to the folding
146  // pattern in quant.dcast (i.e., x -> quant.qcast -> quant.dcast -> y), values
147  // x and y are not guaranteed to be of the same type here, as they may use
148  // different quantization parameters.
149  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
150  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
151  return {};
152  return srcDcastOp.getInput();
153 }
154 
155 FloatType QuantizeCastOp::getFloatType() {
156  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
157 }
158 
159 QuantizedType QuantizeCastOp::getQuantizedType() {
160  return cast<QuantizedType>(getElementTypeOrSelf(getResult().getType()));
161 }
162 
163 
164 //===----------------------------------------------------------------------===//
165 // StorageCastOp
166 //===----------------------------------------------------------------------===//
167 
168 LogicalResult StorageCastOp::verify() {
169  auto quantizedType = getQuantizedType();
170  auto integerType = getIntegerType();
171  if (quantizedType.getStorageType() != integerType)
172  return emitError(
173  "storage type in quantized type expected to match integer type");
174 
175  // Verify integrity of per-axis quantization information, if available. While
176  // the quantization type may appear in the input or the result, their tensor
177  // shapes are guaranteed to be identical at this point.
178  return verifyPerAxisQuantization(*this, quantizedType, getInput().getType());
179 }
180 
181 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
182  // Matches x -> quant.scast -> quant.scast -> y, replacing the second
183  // quant.scast with the value of x if the casts invert each other.
184  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
185  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
186  return {};
187  return srcScastOp.getInput();
188 }
189 
190 IntegerType StorageCastOp::getIntegerType() {
191  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
192  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
193  return integerType;
194 
195  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
196  return cast<IntegerType>(resultScalarType);
197 }
198 
199 QuantizedType StorageCastOp::getQuantizedType() {
200  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
201  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
202  return quantizedType;
203 
204  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
205  return cast<QuantizedType>(resultScalarType);
206 }
207 
208 
209 } // namespace quant
210 } // namespace mlir
211 
212 #define GET_OP_CLASSES
213 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
214 
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not s...
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425