20#include "llvm/ADT/SmallVectorExtras.h"
25#define GEN_PASS_DEF_NORMALIZEQUANTTYPES
26#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
// Predicate over a tensor type. The element type is converted with `cast<>`
// (not `dyn_cast<>`), so callers are expected to pass a tensor whose element
// type already is UniformQuantizedSubChannelType — TODO confirm against callers.
// Returns true when the object reached from that element type reports exactly
// one element.
// NOTE(review): the accessor calls between the cast and `.getNumElements()`
// are not visible in this listing; presumably they reach the type of the
// scales attribute — confirm against the full source.
36static bool isConvertibleToPerTensor(TensorType tensorType) {
37 return cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
40 .getNumElements() == 1;
// Predicate over a tensor type. `shape` is derived from the element type after
// a `cast<>` to UniformQuantizedSubChannelType; the function returns true when
// exactly one entry of `shape` is different from 1.
// NOTE(review): the accessor calls that turn the cast result into `shape` are
// not visible in this listing; presumably this is the shape of the scales
// attribute — confirm against the full source.
49static bool isConvertibleToPerAxis(TensorType tensorType) {
50 auto shape = cast<UniformQuantizedSubChannelType>(tensorType.getElementType())
54 return llvm::count_if(shape, [](
int64_t dim) {
return dim != 1; }) == 1;
// Type-conversion callback. For a tensor type whose element type is a
// UniformQuantizedSubChannelType it returns a clone of the tensor carrying a
// different element type:
//   * `perTensorType` when isConvertibleToPerTensor() holds, built from the
//     first scale and first zero point;
//   * `perAxisType` when isConvertibleToPerAxis() holds, built from all scales
//     and zero points plus the index of the one non-unit scales dimension.
// NOTE(review): the bodies of the two early-exit checks, the calls that
// construct `perTensorType` / `perAxisType`, and the final fall-through return
// are not visible in this listing; presumably the input type is returned
// unchanged in the non-matching cases — confirm against the full source.
61 static Type convertType(Type type) {
62 auto tensorType = dyn_cast<TensorType>(type);
68 dyn_cast<UniformQuantizedSubChannelType>(tensorType.getElementType());
69 if (!subChannelType) {
// Per-tensor case: only element 0 of the scales and zero points is read.
73 if (isConvertibleToPerTensor(tensorType)) {
75 subChannelType.getScales().getValues<APFloat>()[0].convertToDouble();
77 subChannelType.getZeroPoints().getValues<APInt>()[0].getSExtValue();
79 subChannelType.getFlags(), subChannelType.getStorageType(),
80 subChannelType.getExpressedType(), scale, zeroPoint,
81 subChannelType.getStorageTypeMin(),
82 subChannelType.getStorageTypeMax());
83 return tensorType.clone(perTensorType);
// Per-axis case: find the first scales dimension whose extent is not 1, then
// convert every scale to double and every zero point to a sign-extended
// integer.
86 if (isConvertibleToPerAxis(tensorType)) {
87 auto shape = subChannelType.getScales().getType().getShape();
88 const auto *quantizedDimItr =
89 llvm::find_if(shape, [](
int64_t dim) {
return dim != 1; });
90 auto scales = llvm::map_to_vector(
91 subChannelType.getScales().getValues<APFloat>(),
92 [](
const APFloat &scale) { return scale.convertToDouble(); });
93 auto zeroPoints = llvm::map_to_vector(
94 subChannelType.getZeroPoints().getValues<APInt>(),
95 [](
const APInt &zeroPoint) { return zeroPoint.getSExtValue(); });
// `quantizedDimItr - shape.begin()` is the position of the non-unit dimension
// within the scales shape; presumably it is passed as the quantized-dimension
// argument — confirm against the constructor being called.
97 subChannelType.getFlags(), subChannelType.getStorageType(),
98 subChannelType.getExpressedType(), scales, zeroPoints,
99 quantizedDimItr - shape.begin(), subChannelType.getStorageTypeMin(),
100 subChannelType.getStorageTypeMax());
101 return tensorType.clone(perAxisType);
// Constructor: registers `convertType` as a conversion callback through
// addConversion(). No other state is set up on this line.
107 explicit NormalizedQuantTypesConverter() { addConversion(convertType); }
// Constructor of the conversion pattern; takes the type converter by reference
// and the MLIR context by pointer.
// NOTE(review): the initializer list / body is not visible in this listing;
// presumably both arguments are forwarded to the base pattern class — confirm.
115 ConvertGenericOpwithSubChannelType(
TypeConverter &typeConverter,
116 MLIRContext *context)
// Rewrites `op` into `newOp`, an operation built from the same location, name,
// attributes, properties storage, successors and region count, but with the
// converted result types and the `operands` supplied to this call.
// NOTE(review): the return type, the failure-path returns and the callee that
// creates `newOp` are not visible in this listing — confirm against the full
// source.
120 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
121 ConversionPatternRewriter &rewriter)
const final {
// Convert every result type of `op` through the pattern's type converter.
122 SmallVector<Type> resultTypes;
123 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
127 op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
128 op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
// Walk the regions of `op` and `newOp` pairwise: inline the contents of each
// old region at the end of the matching new region, then convert the types
// in that new region with the same type converter.
129 for (
auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
130 Region &before = std::get<0>(regions);
131 Region &parent = std::get<1>(regions);
132 rewriter.inlineRegionBefore(before, parent, parent.end());
133 if (failed(rewriter.convertRegionTypes(&parent, *typeConverter)))
// Insert the new operation through the rewriter and replace `op` with its
// results.
136 rewriter.insert(newOp);
137 rewriter.replaceOp(op, newOp->getResults());
// Pass class deriving from impl::NormalizeQuantTypesBase. runOnOperation()
// runs a full dialect conversion on the pass's operation using
// NormalizedQuantTypesConverter.
// NOTE(review): the declarations of `context` and `target`, the closing
// lines of the legality lambdas, and the handling of a failed conversion are
// not visible in this listing — confirm against the full source.
143class NormalizeQuantTypes
144 :
public impl::NormalizeQuantTypesBase<NormalizeQuantTypes> {
146 void runOnOperation()
override {
150 NormalizedQuantTypesConverter typeConverter;
// func::FuncOp is legal only when both its function type (signature) and its
// body are legal according to the type converter.
154 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
155 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
156 typeConverter.isLegal(&op.getBody());
// Any other operation is legal when all of its operand types and result types
// are legal according to the type converter.
158 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
159 return typeConverter.isLegal(op->getOperandTypes()) &&
160 typeConverter.isLegal(op->getResultTypes());
// Patterns: the function-signature type-conversion pattern for func::FuncOp,
// plus the generic ConvertGenericOpwithSubChannelType pattern.
164 RewritePatternSet
patterns(context);
165 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
167 patterns.add<ConvertGenericOpwithSubChannelType>(typeConverter, context);
// Apply the conversion to the operation this pass runs on; `patterns` is
// moved into the call.
171 applyFullConversion(getOperation(),
target, std::move(
patterns))))
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns