18 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
// Verifies that a per-axis quantized type is used consistently with the type
// that contains it, emitting an error on `op` when a check fails.
// NOTE(review): the embedded numbering skips lines (37, 39, 41, 43-44, 49,
// 54, ...), so some conditions, braces and the final return are not visible
// in this listing; only the visible checks are described below.
35 LogicalResult verifyPerAxisQuantization(
36 Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
// View the container type as a tensor. The guard for the error below is not
// visible; presumably it fires when this dyn_cast fails - TODO confirm.
38 auto tensorType = dyn_cast<TensorType>(containerType);
40 return op->emitError(
"scalar types may not use per-axis quantization");
// Rank check; the statement executed for unranked tensors is not visible here.
42 if (!tensorType.hasRank())
// The quantized dimension must index an existing tensor dimension. Only the
// upper bound is checked in the visible code.
45 int32_t quantizedDimension =
46 uniformQuantizedPerAxisType.getQuantizedDimension();
47 if ((int64_t)quantizedDimension >= tensorType.getRank())
48 return op->emitError(
"quantized dimension must be less than tensor rank");
// A statically known size along the quantized dimension must equal the number
// of scales; a dynamic size (ShapedType::kDynamic) is not compared.
50 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
51 if (quantizedDimensionSize != ShapedType::kDynamic &&
52 quantizedDimensionSize !=
53 (int64_t)uniformQuantizedPerAxisType.getScales().size())
55 "quantized dimension size does not match number of scales");
// Verifies that a sub-channel quantized type is used consistently with the
// type that contains it, emitting an error on `op` when a check fails.
// NOTE(review): the embedded numbering skips many lines (75, 77, 79, 81, 83,
// 85, 89-91, 105, 108, 111-124, 128-130, 134, 136-138, ...), so several
// conditions, declarations (e.g. `scaleShape`, `index`, `scaleDim`) and the
// final return are not visible; only the visible checks are described.
74 LogicalResult verifySubChannelQuantization(
76 UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
// View the container type as a tensor. The guard for the error below is not
// visible; presumably it fires when this dyn_cast fails - TODO confirm.
78 auto tensorType = dyn_cast<TensorType>(containerType);
80 return op->emitError(
"scalar types may not use sub-channel quantization");
// Unranked tensors are rejected with the message below (the emitError call
// itself is on a line omitted from this listing).
82 if (!tensorType.hasRank())
84 "tensor containing the sub-channel quantized type must be ranked");
// Pairs of (quantized dimension, block size) taken from the quantized type.
86 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
87 uniformQuantizedSubChannelType.getBlockSizeInfo();
88 auto shape = tensorType.getShape();
// Expected shape of the scales: starts as all ones (one entry per tensor
// dimension) and is overwritten below for each quantized dimension.
92 SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
93 for (
auto [quantizedDimension, blockSize] : blockSizeInfo) {
// Each quantized dimension must index an existing tensor dimension.
94 if (quantizedDimension >= tensorType.getRank())
95 return op->emitError()
96 <<
"quantized dimension " << quantizedDimension
97 <<
" must be less than tensor rank " << tensorType.getRank();
// A static dimension size must be an exact multiple of its block size.
// NOTE(review): `% blockSize` requires blockSize != 0; no such check is
// visible here - confirm it is validated elsewhere.
98 if (!tensorType.isDynamicDim(quantizedDimension) &&
99 tensorType.getDimSize(quantizedDimension) % blockSize != 0)
100 return op->emitError()
101 <<
"tensor dimension size "
102 << tensorType.getDimSize(quantizedDimension) <<
" at axis "
103 << quantizedDimension
104 <<
" must be divisible by the corresponding block size "
// Dynamic tensor dimensions give a dynamic expected scale dimension; the
// assignment at 109-110 computes size / blockSize (presumably the `else`
// branch, whose keyword is on an omitted line).
106 if (tensorType.isDynamicDim(quantizedDimension))
107 expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
109 expectedScaleShape[quantizedDimension] =
110 tensorType.getDimSize(quantizedDimension) / blockSize;
// Any tensor dimension of size zero is rejected.
125 if (llvm::find(tensorType.getShape(), 0) != tensorType.getShape().end()) {
126 return op->emitError() <<
"tensor dimension size of zero is not allowed "
127 "with sub-channel quantization";
// Shape of the scales attribute (the declaration receiving this value is on an
// omitted line; presumably `scaleShape`). Its rank must equal the tensor rank.
131 uniformQuantizedSubChannelType.getScales().getType().getShape();
132 if (scaleShape.size() != shape.size()) {
133 return op->emitError() <<
"Rank of scales " << scaleShape.size()
135 <<
"the rank of the tensor " << shape.size();
// Per-axis comparison of the scales shape against the expected shape; dynamic
// expected entries are not compared. The enclosing loop header is not visible.
139 if (expectedScaleShape[index] != ShapedType::kDynamic &&
140 expectedScaleShape[index] != scaleShape[index])
141 return op->emitError() <<
"dimension size " << scaleDim
142 <<
" of scales tensor at axis " << index
143 <<
" should match (tensor dimension at axis / "
144 "block sizes at axis) = "
145 << expectedScaleShape[index];
// Common verification for ops that relate a quantized type to a float type:
// checks the expressed type, then delegates to the per-axis or sub-channel
// verifier depending on the kind of quantized type.
// NOTE(review): closing braces, the tail of the sub-channel call and the final
// return are on lines omitted from this listing.
164 LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
165 FloatType floatType, Type containerType) {
// The quantized type's expressed type must be exactly `floatType`.
166 if (quantizedType.getExpressedType() != floatType)
167 return op->emitError(
168 "expressed type in quantized type expected to match float type");
// Per-axis quantized types get the additional per-axis checks.
171 if (
auto quantizedPerAxisType =
172 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
173 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
// Sub-channel quantized types get the additional sub-channel checks.
176 if (
auto quantizedSubChannelType =
177 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
178 return verifySubChannelQuantization(op, quantizedSubChannelType,
// Dialect initialization hook: registers the five quantized types listed in
// the addTypes<> call below.
// NOTE(review): lines 195-196 and everything after 197 are omitted from this
// listing; the generated-ops include below presumably feeds an operation
// registration call that is not visible - TODO confirm.
192 void QuantDialect::initialize() {
193 addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
194 UniformQuantizedPerAxisType, UniformQuantizedSubChannelType>();
197 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
// NOTE(review): the enclosing function header is not visible in this listing;
// given the definitions that follow, this is presumably the body of
// DequantizeCastOp::verify - TODO confirm. The visible statement forwards this
// op, its quantized type and its float type to verifyQuantizationOp; the
// remaining argument(s) are on an omitted line.
207 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folder for DequantizeCastOp: when the input is produced by a QuantizeCastOp,
// the visible code returns that op's own input.
// NOTE(review): lines 212-214 and 216-217 are omitted; presumably a null check
// on `srcQcastOp` precedes the assert - TODO confirm.
211 OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
215 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
// The value being returned is asserted to have the same type as this op's
// result.
218 assert(srcQcastOp.getInput().getType() ==
getType());
219 return srcQcastOp.getInput();
// Returns the QuantizedType associated with this DequantizeCastOp.
// NOTE(review): the body is on lines omitted from this listing, so how the
// type is derived cannot be stated from here.
226 QuantizedType DequantizeCastOp::getQuantizedType() {
// NOTE(review): the enclosing function header is not visible in this listing;
// given the definitions that follow, this is presumably the body of
// QuantizeCastOp::verify - TODO confirm. The visible statement forwards this
// op, its quantized type and its float type to verifyQuantizationOp; the
// remaining argument(s) are on an omitted line.
235 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folder for QuantizeCastOp: when the input is produced by a DequantizeCastOp
// whose own input has this op's result type, returns that original input.
// NOTE(review): lines 240-244 and 247 are omitted; the statement taken when
// the condition below holds is not visible (presumably "no fold").
239 OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
245 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
// Bail-out condition: no defining DequantizeCastOp, or a type mismatch.
246 if (!srcDcastOp || srcDcastOp.getInput().getType() !=
getType())
248 return srcDcastOp.getInput();
// Returns the QuantizedType associated with this QuantizeCastOp.
// NOTE(review): the body is on lines omitted from this listing, so how the
// type is derived cannot be stated from here.
255 QuantizedType QuantizeCastOp::getQuantizedType() {
// NOTE(review): the enclosing function header is not visible in this listing;
// given the surrounding definitions, this is presumably the body of
// StorageCastOp::verify - TODO confirm. Several lines (267, 269-272, 276-278,
// 282 onward) are omitted, including call arguments and the final return.
264 auto quantizedType = getQuantizedType();
265 auto integerType = getIntegerType();
// The quantized type's storage type must be exactly the op's integer type.
266 if (quantizedType.getStorageType() != integerType)
268 "storage type in quantized type expected to match integer type");
// Per-axis quantized types get the additional per-axis checks.
273 if (
auto quantizedPerAxisType =
274 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
275 return verifyPerAxisQuantization(*
this, quantizedPerAxisType,
// Sub-channel quantized types get the additional sub-channel checks.
// NOTE(review): `quantizedSunChannelType` looks like a typo for
// `quantizedSubChannelType`; it is a local name only, so behavior is unaffected.
279 if (
auto quantizedSunChannelType =
280 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
281 return verifySubChannelQuantization(*
this, quantizedSunChannelType,
// Folder for StorageCastOp: when the input is produced by another
// StorageCastOp whose own input has this op's result type, returns that
// original input (i.e. a cast followed by its inverse cancels out).
// NOTE(review): lines 290-291 and 294 are omitted; the statement taken when
// the condition below holds is not visible (presumably "no fold").
289 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
292 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
// Bail-out condition: no defining StorageCastOp, or a type mismatch.
293 if (!srcScastOp || srcScastOp.getInput().getType() !=
getType())
295 return srcScastOp.getInput();
// Returns the integer type involved in this storage cast: the visible code
// tests whether `inputScalarType` is an IntegerType and otherwise falls back
// to casting `resultScalarType`.
// NOTE(review): the definitions of `inputScalarType` / `resultScalarType` and
// the statement under the `if` are on omitted lines (299, 301-303);
// presumably the `if` returns `integerType` - TODO confirm.
298 IntegerType StorageCastOp::getIntegerType() {
300 if (
auto integerType = dyn_cast<IntegerType>(inputScalarType))
// Unconditional cast<>: this path expects the result scalar type to be an
// IntegerType.
304 return cast<IntegerType>(resultScalarType);
// Returns the quantized type involved in this storage cast: the input scalar
// type if it is a QuantizedType, otherwise the result scalar type cast to
// QuantizedType.
// NOTE(review): the definitions of `inputScalarType` / `resultScalarType` are
// on lines omitted from this listing (308, 311-312).
307 QuantizedType StorageCastOp::getQuantizedType() {
309 if (
auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
310 return quantizedType;
// Unconditional cast<>: this path expects the result scalar type to be a
// QuantizedType.
313 return cast<QuantizedType>(resultScalarType);
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not supported.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.