24 #define GEN_PASS_DEF_NORMALIZEQUANTTYPES
25 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
/// Returns true when the sub-channel quantization attached to `tensorType`'s
/// element type can be collapsed into a single per-tensor quantization, i.e.
/// when the queried quantity holds exactly one element (the per-tensor branch
/// of convertType then reads element [0] of the scales and zero points).
///
/// Precondition: the element type of `tensorType` must be a
/// UniformQuantizedSubChannelType. `cast<>` (unlike the `dyn_cast<>` used in
/// convertType) does not return null on a mismatch; it asserts.
// NOTE(review): source lines 37-38 are missing from this listing, so the
// accessor chain between the cast and getNumElements() is not visible here;
// presumably it is `.getScales().getType()` (the scales tensor type) -- confirm.
35 static bool isConvertibleToPerTensor(TensorType tensorType) {
36 return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
39 .getNumElements() == 1;
/// Returns true when exactly one dimension of the queried shape has an extent
/// other than 1. In that case the sub-channel quantization varies along a
/// single axis and can be expressed as a per-axis quantized type. An all-ones
/// shape yields false here (the count is 0).
///
/// Precondition: the element type of `tensorType` must be a
/// UniformQuantizedSubChannelType (`cast<>` asserts on a mismatch).
// NOTE(review): source lines 50-52 are missing from this listing; presumably
// they are `.getScales().getType().getShape()`, matching the shape the per-axis
// branch of convertType inspects -- confirm.
48 static bool isConvertibleToPerAxis(TensorType tensorType) {
49 auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
53 return llvm::count_if(shape, [](int64_t dim) {
return dim != 1; }) == 1;
/// Type converter for tensor types whose element type is a
/// UniformQuantizedSubChannelType. It rewrites them into a simpler quantized
/// element type when the sub-channel parameters allow it:
///   - isConvertibleToPerTensor holds -> per-tensor quantized element type;
///   - isConvertibleToPerAxis holds   -> per-axis quantized element type.
// NOTE(review): this listing omits several source lines of the class
// (59, 62-66, 69-71, 73, 75, 77, 83-84, 95, 101-105, 107). These include the
// bodies of the early exits, the declarations that receive `scale`,
// `zeroPoint`, `perTensorType` and `perAxisType`, and the class close.
// Presumably types that cannot be converted are returned unchanged -- confirm
// against the full file.
58 class NormalizedQuantTypesConverter :
public TypeConverter {
// Conversion callback registered by the constructor. Only tensor types
// (dyn_cast<TensorType>) whose element type is a sub-channel quantized type
// (dyn_cast<UniformQuantizedSubChannelType>) are candidates for rewriting.
60 static Type convertType(Type type) {
61 auto tensorType = dyn_cast<TensorType>(type);
67 dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
68 if (!subChannelType) {
// Case 1: per-tensor. Read the first (and, per isConvertibleToPerTensor,
// presumably only) scale and zero point.
// NOTE(review): getSExtValue() interprets the stored zero point as signed;
// confirm this is intended for unsigned storage types.
72 if (isConvertibleToPerTensor(tensorType)) {
74 subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
76 subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
78 subChannelType.getFlags(), subChannelType.getStorageType(),
79 subChannelType.getExpressedType(), scale, zeroPoint,
80 subChannelType.getStorageTypeMin(),
81 subChannelType.getStorageTypeMax());
// Same tensor shape, new element type.
82 return tensorType.clone(perTensorType);
// Case 2: per-axis. The quantized dimension is the first dimension of the
// scales shape whose extent is not 1.
85 if (isConvertibleToPerAxis(tensorType)) {
86 auto shape = subChannelType.getScales().getType().getShape();
87 auto quantizedDimItr =
88 llvm::find_if(shape, [](int64_t dim) {
return dim != 1; });
// Flatten the scales and zero-points attributes into plain double / int64_t
// vectors for the per-axis type.
89 auto scales = llvm::to_vector(llvm::map_range(
90 subChannelType.getScales().getValues<APFloat>(),
91 [](APFloat scale) { return scale.convertToDouble(); }));
92 auto zeroPoints = llvm::to_vector(llvm::map_range(
93 subChannelType.getZeroPoints().getValues<APInt>(),
94 [](APInt zeroPoint) { return zeroPoint.getSExtValue(); }));
96 subChannelType.getFlags(), subChannelType.getStorageType(),
97 subChannelType.getExpressedType(), scales, zeroPoints,
// The iterator's distance from shape.begin() is the quantized dimension index.
98 quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
99 subChannelType.getStorageTypeMax());
100 return tensorType.clone(perAxisType);
// Registers convertType as this converter's conversion callback.
106 explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
/// Catch-all conversion pattern (MatchAnyOpTypeTag). It rebuilds any operation
/// with its result types converted by the type converter and its operands
/// taken from the already-remapped `operands`. It then moves the original
/// regions into the new operation and converts their types.
// NOTE(review): this listing omits source lines 113, 117-118, 123-125, 133-134
// and 137-141. These hold the access specifier, the return type, the
// failure/success returns, the `newOp` declaration and the class close. The
// comments below describe only the visible statements.
112 class ConvertGenericOpwithSubChannelType :
public ConversionPattern {
114 ConvertGenericOpwithSubChannelType(TypeConverter &typeConverter,
115 MLIRContext *context)
// Benefit 0; matches operations of any type.
116 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) {}
// Replaces `op` with a new operation that has the same location, name,
// attributes, properties and successors, but converted result types and
// remapped operands.
119 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
120 ConversionPatternRewriter &rewriter)
const final {
121 SmallVector<Type> resultTypes;
// Give up if any result type cannot be converted (the return on source line
// 123 is not shown; presumably `return failure();`).
122 if (
failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
// Arguments of the replacement op's creation call. Only the region count is
// passed; the regions are populated by the loop below.
126 op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
127 op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
// For each (old, new) region pair: move the old region's blocks to the end of
// the new region, then convert that region's block-argument types. Stop if
// the conversion fails.
128 for (
auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
129 Region &before = std::get<0>(regions);
130 Region &parent = std::get<1>(regions);
131 rewriter.inlineRegionBefore(before, parent, parent.end());
132 if (
failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
// Insert the new op at the rewriter's insertion point and redirect all uses
// of the old op's results to the new op's results.
135 rewriter.insert(newOp);
136 rewriter.replaceOp(op, newOp->getResults());
/// Pass driver. It converts sub-channel quantized tensor types to per-tensor /
/// per-axis forms using NormalizedQuantTypesConverter, the FunctionOpInterface
/// signature-conversion pattern and ConvertGenericOpwithSubChannelType.
// NOTE(review): this listing is truncated after source line 166. It also
// omits lines 144, 146-148, 151-152, 156, 160-162 and 165, including the
// declaration of `context` and the remaining arguments of
// populateFunctionOpInterfaceTypeConversionPattern. How the conversion is
// finally applied is not visible here; presumably via applyFullConversion --
// confirm against the full file.
142 class NormalizeQuantTypes
143 :
public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
145 void runOnOperation()
override {
149 NormalizedQuantTypesConverter typeConverter;
150 ConversionTarget target(*context);
// A func.func is legal once its signature and body contain only types the
// converter considers legal (types it would leave unchanged).
153 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
154 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
155 typeConverter.isLegal(&op.getBody());
// Every other op is legal once all of its operand and result types are legal.
157 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
158 return typeConverter.isLegal(op->getOperandTypes()) &&
159 typeConverter.isLegal(op->getResultTypes());
// func.func signatures are handled by the FunctionOpInterface conversion
// pattern; every other op is handled by the catch-all generic-op pattern.
163 RewritePatternSet
patterns(context);
164 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
166 patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
static MLIRContext * getContext(OpFoldResult val)
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns