#include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"

// Verify the integrity of per-axis quantization information, if present.
LogicalResult verifyPerAxisQuantization(
    Operation *op, UniformQuantizedPerAxisType uniformQuantizedPerAxisType,
    Type containerType) {
  // Per-axis quantization is only meaningful inside a tensor container.
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use per-axis quantization");

  // Nothing more to verify for unranked tensors.
  if (!tensorType.hasRank())
    return success();

  int32_t quantizedDimension =
      uniformQuantizedPerAxisType.getQuantizedDimension();
  if ((int64_t)quantizedDimension >= tensorType.getRank())
    return op->emitError("quantized dimension must be less than tensor rank");

  // A statically sized quantized dimension needs exactly one scale per slice.
  int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
  if (quantizedDimensionSize != ShapedType::kDynamic &&
      quantizedDimensionSize !=
          (int64_t)uniformQuantizedPerAxisType.getScales().size())
    return op->emitError(
        "quantized dimension size does not match number of scales");

  return success();
}
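
// Illustrative example (assumed IR, not taken from this file): the following
// per-axis quantized tensor type passes the checks above, because the
// quantized dimension is 1 and that dimension has size 3, matching the three
// scale/zero-point pairs:
//
//   tensor<2x3x!quant.uniform<i8:f32:1, {2.0e-2:10, 5.0e-2:20, 8.0e-2:30}>>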

// Verify the integrity of sub-channel quantization information, if present.
LogicalResult verifySubChannelQuantization(
    Operation *op,
    UniformQuantizedSubChannelType uniformQuantizedSubChannelType,
    Type containerType) {
  auto tensorType = dyn_cast<TensorType>(containerType);
  if (!tensorType)
    return op->emitError("scalar types may not use sub-channel quantization");

  if (!tensorType.hasRank())
    return op->emitError(
        "tensor containing the sub-channel quantized type must be ranked");

  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      uniformQuantizedSubChannelType.getBlockSizeInfo();
  auto shape = tensorType.getShape();

  // Validate each (quantized dimension, block size) pair, computing the
  // expected shape of the scales tensor along the way.
  SmallVector<int64_t> expectedScaleShape(tensorType.getShape().size(), 1);
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    if (quantizedDimension >= tensorType.getRank())
      return op->emitError()
             << "quantized dimension " << quantizedDimension
             << " must be less than tensor rank " << tensorType.getRank();
    if (!tensorType.isDynamicDim(quantizedDimension) &&
        tensorType.getDimSize(quantizedDimension) % blockSize != 0)
      return op->emitError()
             << "tensor dimension size "
             << tensorType.getDimSize(quantizedDimension) << " at axis "
             << quantizedDimension
             << " must be divisible by the corresponding block size "
             << blockSize;
    if (tensorType.isDynamicDim(quantizedDimension))
      expectedScaleShape[quantizedDimension] = ShapedType::kDynamic;
    else
      expectedScaleShape[quantizedDimension] =
          tensorType.getDimSize(quantizedDimension) / blockSize;
  }
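
  // Worked example (illustrative values, not from the original source): for a
  // 6x8 tensor with blockSizeInfo {(0, 3), (1, 4)}, every 3x4 block shares one
  // scale, so expectedScaleShape becomes [6/3, 8/4] = [2, 2].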

  // Zero-sized dimensions are not compatible with sub-channel blocking.
  if (llvm::is_contained(tensorType.getShape(), 0)) {
    return op->emitError() << "tensor dimension size of zero is not allowed "
                              "with sub-channel quantization";
  }

  // The scales tensor must have the same rank as the quantized tensor.
  auto scaleShape =
      uniformQuantizedSubChannelType.getScales().getType().getShape();
  if (scaleShape.size() != shape.size()) {
    return op->emitError() << "Rank of scales " << scaleShape.size()
                           << " must match "
                           << "the rank of the tensor " << shape.size();
  }

  // Each dimension of the scales tensor must equal the expected value
  // computed above, unless that value is dynamic.
  for (auto [index, scaleDim] : llvm::enumerate(scaleShape)) {
    if (expectedScaleShape[index] != ShapedType::kDynamic &&
        expectedScaleShape[index] != scaleDim)
      return op->emitError() << "dimension size " << scaleDim
                             << " of scales tensor at axis " << index
                             << " should match (tensor dimension at axis / "
                                "block sizes at axis) = "
                             << expectedScaleShape[index];
  }

  return success();
}
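
// Summary of the checks above with illustrative (assumed) numbers: for a ?x8
// tensor with block sizes {axis 0: 3, axis 1: 4}, the scales tensor must have
// rank 2 and size 8/4 = 2 along axis 1, while its size along axis 0 is
// unconstrained because the corresponding tensor dimension is dynamic.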

// Common verification for ops that pair a quantized type with its expressed
// float type (quant.dcast and quant.qcast).
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
                                   FloatType floatType, Type containerType) {
  if (quantizedType.getExpressedType() != floatType)
    return op->emitError(
        "expressed type in quantized type expected to match float type");

  // Verify integrity of per-axis quantization information, if present.
  if (auto quantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
    return verifyPerAxisQuantization(op, quantizedPerAxisType, containerType);
  }

  // Verify integrity of sub-channel quantization information, if present.
  if (auto quantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
    return verifySubChannelQuantization(op, quantizedSubChannelType,
                                        containerType);
  }

  return success();
}
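
// Illustrative example (assumed IR, not taken from this file): in
//
//   %q = quant.qcast %fp : tensor<4xf32>
//            to tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//
// the expressed type (f32) of the quantized element type matches the float
// element type of the operand, so verifyQuantizationOp succeeds.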

/// Handles inlining for the quant dialect.
struct QuantInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All quant dialect operations are legal to inline.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};

void QuantDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
      >();
  addInterfaces<QuantInlinerInterface>();
  detail::addBytecodeInterface(this);
}

LogicalResult DequantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getInput().getType());
}

OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
  // Fold quant.dcast(quant.qcast(x)) -> x. Both values have the same type.
  auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
  if (!srcQcastOp)
    return {};
  assert(srcQcastOp.getInput().getType() == getType());
  return srcQcastOp.getInput();
}

FloatType DequantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getResult().getType()));
}
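
// Illustrative input for DequantizeCastOp::fold above (assumed IR, not taken
// from this file):
//
//   %q = quant.qcast %x : tensor<4xf32>
//            to tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//   %y = quant.dcast %q : tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//            to tensor<4xf32>
//
// The fold replaces all uses of %y with %x; the assert holds because the
// qcast input and the dcast result are both tensor<4xf32>.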

LogicalResult QuantizeCastOp::verify() {
  return verifyQuantizationOp(*this, getQuantizedType(), getFloatType(),
                              getResult().getType());
}

OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
  // Fold quant.qcast(quant.dcast(x)) -> x when x already has the result type.
  auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
  if (!srcDcastOp || srcDcastOp.getInput().getType() != getType())
    return {};
  return srcDcastOp.getInput();
}

FloatType QuantizeCastOp::getFloatType() {
  return cast<FloatType>(getElementTypeOrSelf(getInput().getType()));
}
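
// Illustrative input for QuantizeCastOp::fold above (assumed IR, not taken
// from this file):
//
//   %f = quant.dcast %x : tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//            to tensor<4xf32>
//   %y = quant.qcast %f : tensor<4xf32>
//            to tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//
// %y folds to %x. Had the qcast re-quantized to a different quantized type,
// the type check above would reject the fold.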

LogicalResult StorageCastOp::verify() {
  auto quantizedType = getQuantizedType();
  auto integerType = getIntegerType();
  if (quantizedType.getStorageType() != integerType)
    return emitError(
        "storage type in quantized type expected to match integer type");

  // Verify integrity of per-axis quantization information, if present. The
  // input and result share the same shape, so checking against the input type
  // is sufficient.
  if (auto quantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType)) {
    return verifyPerAxisQuantization(*this, quantizedPerAxisType,
                                     getInput().getType());
  }

  // Verify integrity of sub-channel quantization information, if present.
  if (auto quantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType)) {
    return verifySubChannelQuantization(*this, quantizedSubChannelType,
                                        getInput().getType());
  }

  return success();
}
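
// Illustrative example (assumed IR, not taken from this file):
//
//   %s = quant.scast %x : tensor<4xi8>
//            to tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//
// verifies because the storage type (i8) of the quantized element type
// matches the integer element type on the other side of the cast.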

OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
  // Fold quant.scast(quant.scast(x)) -> x when x already has the result type.
  auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
  if (!srcScastOp || srcScastOp.getInput().getType() != getType())
    return {};
  return srcScastOp.getInput();
}
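
// Illustrative input for StorageCastOp::fold above (assumed IR, not taken
// from this file):
//
//   %q = quant.scast %x : tensor<4xi8>
//            to tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//   %y = quant.scast %q : tensor<4x!quant.uniform<i8:f32, 2.0e-2:10>>
//            to tensor<4xi8>
//
// %y folds to %x because the inner scast's input type matches the outer
// scast's result type (tensor<4xi8>).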

IntegerType StorageCastOp::getIntegerType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto integerType = dyn_cast<IntegerType>(inputScalarType))
    return integerType;
  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<IntegerType>(resultScalarType);
}

QuantizedType StorageCastOp::getQuantizedType() {
  auto inputScalarType = getElementTypeOrSelf(getInput().getType());
  if (auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
    return quantizedType;
  auto resultScalarType = getElementTypeOrSelf(getResult().getType());
  return cast<QuantizedType>(resultScalarType);
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"