19#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
// Verifies that a per-axis quantized type is compatible with the container
// type it is used in. Checks visible in this extract (several lines of the
// function, including its parameter list tail and final return, are missing):
//  - a non-tensor (scalar) container is rejected with
//    "scalar types may not use per-axis quantization";
//  - the tensor's rank is checked with hasRank() (the handling on the missing
//    lines is not visible; presumably an error for unranked tensors);
//  - the quantized dimension must be less than the tensor rank;
//  - when the size along the quantized dimension is static, it must equal the
//    number of scales in the per-axis type.
36LogicalResult verifyPerAxisQuantization(
39 auto tensorType = dyn_cast<TensorType>(containerType);
41 return op->emitError(
"scalar types may not use per-axis quantization");
43 if (!tensorType.hasRank())
// Widen the int32_t dimension to int64_t before comparing against the rank.
// NOTE(review): a negative quantized dimension is not rejected by this
// check; presumably the quantized type's own verifier guarantees it is
// non-negative. Confirm.
46 int32_t quantizedDimension =
47 uniformQuantizedPerAxisType.getQuantizedDimension();
48 if ((int64_t)quantizedDimension >= tensorType.getRank())
49 return op->emitError(
"quantized dimension must be less than tensor rank");
// A dynamic size along the quantized axis cannot be compared with the scale
// count here, so only static sizes are checked.
51 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
52 if (quantizedDimensionSize != ShapedType::kDynamic &&
53 quantizedDimensionSize !=
54 (int64_t)uniformQuantizedPerAxisType.getScales().size())
56 "quantized dimension size does not match number of scales");
// Verifies that a sub-channel (block-wise) quantized type is compatible with
// the container type it is used in. Checks visible in this extract (several
// lines, including parts of the signature and the final return, are missing):
//  - a non-tensor (scalar) container is rejected;
//  - the tensor must be ranked;
//  - every quantized dimension in the block-size info must be less than the
//    tensor rank;
//  - every static tensor dimension that is blocked must be divisible by its
//    block size;
//  - no tensor dimension may have size zero;
//  - the scales tensor must have the same rank as the data tensor;
//  - every statically known expected scale size must match the actual scales
//    shape on that axis.
75LogicalResult verifySubChannelQuantization(
79 auto tensorType = dyn_cast<TensorType>(containerType);
81 return op->emitError(
"scalar types may not use sub-channel quantization");
83 if (!tensorType.hasRank())
85 "tensor containing the sub-channel quantized type must be ranked");
87 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
88 uniformQuantizedSubChannelType.getBlockSizeInfo();
89 auto shape = tensorType.getShape();
// Expected shape of the scales tensor, one entry per tensor axis. Axes without
// a block size keep the default of 1; blocked axes are filled in by the loop
// below.
93 SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
94 for (
auto [quantizedDimension, blockSize] : blockSizeInfo) {
// NOTE(review): as in the per-axis verifier, only the upper bound of the
// quantized dimension is checked here; a negative value is presumably ruled
// out elsewhere. Confirm.
95 if (quantizedDimension >= tensorType.getRank())
96 return op->emitError()
97 <<
"quantized dimension " << quantizedDimension
98 <<
" must be less than tensor rank " << tensorType.getRank();
// Divisibility can only be checked when the tensor dimension is static.
99 if (!tensorType.isDynamicDim(quantizedDimension) &&
100 tensorType.getDimSize(quantizedDimension) % blockSize != 0)
101 return op->emitError()
102 <<
"tensor dimension size "
103 << tensorType.getDimSize(quantizedDimension) <<
" at axis "
104 << quantizedDimension
105 <<
" must be divisible by the corresponding block size "
// Dynamic tensor axes get a dynamic expected scale size; static axes expect
// (tensor dimension / block size). The branch structure between these lines
// is partly missing from this extract (presumably an `else`).
107 if (tensorType.isDynamicDim(quantizedDimension))
108 expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
110 expectedScaleShape[quantizedDimension] =
111 tensorType.getDimSize(quantizedDimension) / blockSize;
// Zero-sized tensor dimensions are rejected outright for sub-channel types.
126 if (llvm::is_contained(tensorType.getShape(), 0)) {
127 return op->emitError() <<
"tensor dimension size of zero is not allowed "
128 "with sub-channel quantization";
132 uniformQuantizedSubChannelType.getScales().getType().getShape();
// The scales tensor must have the same rank as the quantized tensor.
133 if (scaleShape.size() != shape.size()) {
134 return op->emitError() <<
"Rank of scales " << scaleShape.size()
136 <<
"the rank of the tensor " << shape.size();
// Compare each statically known expected scale size with the actual scales
// shape; axes whose expected size is dynamic are not checked.
// NOTE(review): `scaleDim` is bound from llvm::enumerate(expectedScaleShape),
// so it equals expectedScaleShape[index]. The diagnostic below therefore
// prints the expected size twice and never the actual scales size.
// `scaleShape[index]` is presumably what "dimension size ..." should report.
// TODO fix.
139 for (
auto [index, scaleDim] : llvm::enumerate(expectedScaleShape)) {
140 if (expectedScaleShape[index] != ShapedType::kDynamic &&
141 expectedScaleShape[index] != scaleShape[index])
142 return op->emitError() <<
"dimension size " << scaleDim
143 <<
" of scales tensor at axis " << index
144 <<
" should match (tensor dimension at axis / "
145 "block sizes at axis) = "
146 << expectedScaleShape[index];
// Shared verifier used by DequantizeCastOp::verify and QuantizeCastOp::verify.
// It first requires the quantized type's expressed type to equal the op's
// float type. It then dispatches on the concrete quantized type:
//  - UniformQuantizedPerAxisType   -> verifyPerAxisQuantization;
//  - UniformQuantizedSubChannelType -> verifySubChannelQuantization.
// The handling of other quantized types is on lines missing from this
// extract.
//
// @param op             operation used to emit diagnostics.
// @param quantizedType  quantized element type involved in the cast.
// @param floatType      float element type involved in the cast.
// @param containerType  type passed on to the per-axis / sub-channel checks.
165LogicalResult verifyQuantizationOp(Operation *op,
QuantizedType quantizedType,
166 FloatType floatType, Type containerType) {
167 if (quantizedType.getExpressedType() != floatType)
168 return op->emitError(
169 "expressed type in quantized type expected to match float type");
172 if (
auto quantizedPerAxisType =
173 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
174 return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
177 if (
auto quantizedSubChannelType =
178 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
179 return verifySubChannelQuantization(op, quantizedSubChannelType,
// Inliner interface for the quant dialect; it is registered in
// QuantDialect::initialize via addInterfaces<QuantInlinerInterface>().
// It inherits the DialectInlinerInterface constructors and overrides
// (`final`) the per-operation isLegalToInline hook.
// NOTE(review): the body of isLegalToInline is not visible in this extract;
// confirm what it returns before relying on it.
187struct QuantInlinerInterface :
public DialectInlinerInterface {
188 using DialectInlinerInterface::DialectInlinerInterface;
190 bool isLegalToInline(Operation *, Region *,
bool, IRMapping &)
const final {
// Dialect initialization. It pulls in the generated op definitions from
// QuantOps.cpp.inc and attaches QuantInlinerInterface to the dialect. The
// statements surrounding the include are missing from this extract
// (presumably an addOperations<...> list; confirm).
201void QuantDialect::initialize() {
207#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
210 addInterfaces<QuantInlinerInterface>();
// Verifier for DequantizeCastOp. It delegates to verifyQuantizationOp with the
// op's quantized type and float type. The final (container type) argument is
// on a line missing from this extract.
217LogicalResult DequantizeCastOp::verify() {
218 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folds dequantize(quantize(x)) to x. The input must be produced by a
// QuantizeCastOp. The assert requires the original value's type to equal
// this op's result type. The null/early-return handling for a non-matching
// producer is on lines missing from this extract.
222OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
226 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
229 assert(srcQcastOp.getInput().getType() ==
getType());
230 return srcQcastOp.getInput();
// Returns the float type involved in the cast (body not visible in this
// extract).
233FloatType DequantizeCastOp::getFloatType() {
// Verifier for QuantizeCastOp. It delegates to verifyQuantizationOp with the
// op's quantized type and float type. The final (container type) argument is
// on a line missing from this extract.
245LogicalResult QuantizeCastOp::verify() {
246 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folds quantize(dequantize(x)) to x, but only when the input comes from a
// DequantizeCastOp whose own input type equals this op's result type. The
// early-return value for the non-matching case is on a missing line
// (presumably no fold).
250OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
256 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
257 if (!srcDcastOp || srcDcastOp.getInput().getType() !=
getType())
259 return srcDcastOp.getInput();
// Returns the float type involved in the cast (body not visible in this
// extract).
262FloatType QuantizeCastOp::getFloatType() {
// Verifier for StorageCastOp. The quantized type's storage type must equal
// the op's integer type. Per-axis and sub-channel quantized types are then
// checked with the shared helpers; their trailing arguments are on lines
// missing from this extract.
// NOTE(review): the local name `quantizedSunChannelType` is a typo for
// `quantizedSubChannelType`; rename in a code change.
274LogicalResult StorageCastOp::verify() {
275 auto quantizedType = getQuantizedType();
276 auto integerType = getIntegerType();
277 if (quantizedType.getStorageType() != integerType)
279 "storage type in quantized type expected to match integer type");
284 if (
auto quantizedPerAxisType =
285 dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
286 return verifyPerAxisQuantization(*
this, quantizedPerAxisType,
290 if (
auto quantizedSunChannelType =
291 dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
292 return verifySubChannelQuantization(*
this, quantizedSunChannelType,
// Folds scast(scast(x)) to x when x's type equals this op's result type. The
// early-return value for the non-matching case is on a missing line
// (presumably no fold).
300OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
303 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
304 if (!srcScastOp || srcScastOp.getInput().getType() !=
getType())
306 return srcScastOp.getInput();
// Returns the integer side of the cast. If the input's scalar type is an
// IntegerType, that branch's return is on a missing line. Otherwise the
// result's scalar type is converted with a checked cast<IntegerType>.
// NOTE(review): inputScalarType/resultScalarType are declared on a missing
// line (presumably via getElementTypeOrSelf); confirm.
309IntegerType StorageCastOp::getIntegerType() {
311 if (
auto integerType = dyn_cast<IntegerType>(inputScalarType))
315 return cast<IntegerType>(resultScalarType);
// Body of what is presumably StorageCastOp::getQuantizedType(); its header is
// on lines missing from this extract. It returns the input's scalar type when
// that is a QuantizedType. Otherwise it returns the result's scalar type
// converted with a checked cast<QuantizedType>.
320 if (
auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
321 return quantizedType;
324 return cast<QuantizedType>(resultScalarType);
330#define GET_OP_CLASSES
331#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
A quantized type that maps storage to/from expressed types in an unspecified way.
A quantized type that infers its range from given min/max values.
Base class for all quantized types known to this dialect.
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not supported.
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.