18 #include "mlir/Dialect/Quant/IR/QuantOpsDialect.cpp.inc"
// Checks the per-axis quantization information carried by `quantizedType`
// against `containerType`, reporting violations through `op->emitError`.
//
// NOTE(review): this listing skips source lines (the embedded line numbers
// jump, e.g. 38 -> 40, 41 -> 44, 57 -> 59). The remaining parameter(s), some
// conditions and the non-error returns are therefore not visible here —
// confirm against the complete file before editing.
37 LogicalResult verifyPerAxisQuantization(Operation *op,
38 QuantizedType quantizedType,
// Only UniformQuantizedPerAxisType is inspected; the dyn_cast result is tested
// for null on the next line (the statement that test guards is not shown).
40 auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
41 if (!quantizedPerAxisType)
// The container is expected to be a tensor type; the error below is the
// diagnostic used for non-tensor ("scalar") containers.
44 auto tensorType = dyn_cast<TensorType>(containerType);
46 return op->emitError(
"scalar types may not use per-axis quantization");
// Tensors without a rank are special-cased (the guarded statement is not shown).
48 if (!tensorType.hasRank())
// The quantized dimension must be a valid index into the tensor's dimensions.
51 int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
52 if (quantizedDimension >= tensorType.getRank())
53 return op->emitError(
"quantized dimension must be less than tensor rank");
// The size along the quantized dimension is compared with the number of
// scales; a size equal to ShapedType::kDynamic is exempt from the comparison.
// NOTE(review): the C-style (int64_t) cast could be a static_cast<int64_t>.
55 int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
56 if (quantizedDimensionSize != ShapedType::kDynamic &&
57 quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
"quantized dimension size does not match number of scales");
// Verification helper for ops that relate a quantized type to a float type.
// It is invoked from the op verifiers in this file with the op, its quantized
// type and its float type.
//
// op            - operation used to emit diagnostics (passed on to the
//                 per-axis check).
// quantizedType - quantized type whose expressed type is checked.
// floatType     - float type the expressed type must equal.
// containerType - forwarded unchanged to verifyPerAxisQuantization.
77 LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
78 FloatType floatType, Type containerType) {
// The expressed type must be exactly `floatType`. NOTE(review): line 80, which
// opens the error call for the message below, is missing from this listing.
79 if (quantizedType.getExpressedType() != floatType)
"expressed type in quantized type expected to match float type");
// Remaining checks are delegated to the per-axis verifier.
84 return verifyPerAxisQuantization(op, quantizedType, containerType);
// Dialect initialization: registers the four quantized type classes listed
// below with the dialect through addTypes.
94 void QuantDialect::initialize() {
95 addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
96 UniformQuantizedPerAxisType>();
// NOTE(review): lines 97-98 are missing from this listing. This include of the
// generated ops file presumably sits inside an operation-registration call
// (e.g. addOperations with GET_OP_LIST) — confirm against the complete file.
99 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
// Tail of an op verifier whose signature is not visible in this listing
// (presumably DequantizeCastOp::verify, given the definitions that follow —
// confirm). It forwards this op, its quantized type and its float type to
// verifyQuantizationOp; the remaining argument(s) are cut off here.
110 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folds this op to the input of the QuantizeCastOp that produces this op's
// input, i.e. a quantize-cast that is immediately undone by this op.
//
// NOTE(review): lines 115-117 and 119-120 are missing from this listing; the
// null check on `srcQcastOp` is presumably among them — confirm.
114 OpFoldResult DequantizeCastOp::fold(FoldAdaptor adaptor) {
// getDefiningOp<QuantizeCastOp>() looks up the producer of the input value.
118 auto srcQcastOp = getInput().getDefiningOp<QuantizeCastOp>();
// Here the type equality is asserted rather than tested: the producer's input
// type is expected to equal this op's result type.
121 assert(srcQcastOp.getInput().getType() ==
getType());
122 return srcQcastOp.getInput();
// Returns a QuantizedType for this op. NOTE(review): only the signature is
// present in this listing; how the type is derived is not shown.
129 QuantizedType DequantizeCastOp::getQuantizedType() {
// Tail of an op verifier whose signature is not visible in this listing
// (presumably QuantizeCastOp::verify, given the definitions that follow —
// confirm). It forwards this op, its quantized type and its float type to
// verifyQuantizationOp; the remaining argument(s) are cut off here.
139 return verifyQuantizationOp(*
this, getQuantizedType(),
getFloatType(),
// Folds this op to the input of the DequantizeCastOp that produces this op's
// input, provided that input already has this op's result type.
143 OpFoldResult QuantizeCastOp::fold(FoldAdaptor adaptor) {
149 auto srcDcastOp = getInput().getDefiningOp<DequantizeCastOp>();
// Bail-out condition: the input is not produced by a DequantizeCastOp, or the
// producer's input type differs from this op's result type.
// NOTE(review): the statement this condition guards (line 151) is not shown.
150 if (!srcDcastOp || srcDcastOp.getInput().getType() !=
getType())
152 return srcDcastOp.getInput();
// Returns a QuantizedType for this op. NOTE(review): only the signature is
// present in this listing; how the type is derived is not shown.
159 QuantizedType QuantizeCastOp::getQuantizedType() {
// Body fragment of an op verifier whose signature is not visible in this
// listing (presumably StorageCastOp::verify, given the definitions that
// follow — confirm).
169 auto quantizedType = getQuantizedType();
170 auto integerType = getIntegerType();
// The quantized type's storage type must be exactly the op's integer type.
// NOTE(review): line 172, which opens the error call for the message below,
// is missing from this listing.
171 if (quantizedType.getStorageType() != integerType)
"storage type in quantized type expected to match integer type");
// Per-axis checks are run against the type of the op's input.
178 return verifyPerAxisQuantization(*
this, quantizedType, getInput().
getType());
// Folds this op to the input of another StorageCastOp that produces this op's
// input, provided that input already has this op's result type.
181 OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) {
184 auto srcScastOp = getInput().getDefiningOp<StorageCastOp>();
// Bail-out condition: the input is not produced by a StorageCastOp, or the
// producer's input type differs from this op's result type.
// NOTE(review): the statement this condition guards (line 186) is not shown.
185 if (!srcScastOp || srcScastOp.getInput().getType() !=
getType())
187 return srcScastOp.getInput();
// Returns the IntegerType involved in this storage cast. The input scalar type
// is tried first via dyn_cast; otherwise the result scalar type is converted
// with an unconditional cast<IntegerType>.
//
// NOTE(review): the definitions of `inputScalarType` / `resultScalarType` and
// the statement under the `if` are missing from this listing — confirm.
190 IntegerType StorageCastOp::getIntegerType() {
192 if (
auto integerType = dyn_cast<IntegerType>(inputScalarType))
196 return cast<IntegerType>(resultScalarType);
// Returns the QuantizedType involved in this storage cast. If the input scalar
// type is a QuantizedType it is returned; otherwise the result scalar type is
// converted with an unconditional cast<QuantizedType>.
//
// NOTE(review): the definitions of `inputScalarType` / `resultScalarType` are
// missing from this listing — confirm how they are computed.
199 QuantizedType StorageCastOp::getQuantizedType() {
201 if (
auto quantizedType = dyn_cast<QuantizedType>(inputScalarType))
202 return quantizedType;
205 return cast<QuantizedType>(resultScalarType);
212 #define GET_OP_CLASSES
213 #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
FloatType getFloatType(MLIRContext *context, unsigned width)
Returns a supported MLIR floating point type of the given bit width or null if the bit width is not supported.
void addBytecodeInterface(QuantDialect *dialect)
Add the interfaces necessary for encoding the quantization dialect components in bytecode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.