29 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
// Type converter for stripping quantized types: it registers conversions that
// replace a quantized type, or a tensor whose element type is quantized, with
// the corresponding storage type.
// NOTE(review): this listing is an excerpt; lines between the numbered source
// lines (closing braces, some headers) are not shown here.
34 class QuantizedTypeConverter :
public TypeConverter {
// Converts a quantized type to the storage type it reports.
36 static Type convertQuantizedType(QuantizedType quantizedType) {
37 return quantizedType.getStorageType();
// Converts a tensor type: when its element type is a QuantizedType, the tensor
// is cloned with the converted (storage) element type. The result for any
// other tensor is on a line not visible in this excerpt.
40 static Type convertTensorType(TensorType tensorType) {
41 if (
auto quantizedType =
42 dyn_cast<QuantizedType>(tensorType.getElementType()))
43 return tensorType.clone(convertQuantizedType(quantizedType));
// Tail of a helper whose header is outside this excerpt: it asserts that
// exactly one input value was supplied and creates a quant::StorageCastOp
// from `loc`, `type`, and that single input.
48 ValueRange inputs, Location loc) {
49 assert(inputs.size() == 1);
50 return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
// Registers the conversions in this order: an identity rule for any type,
// then the quantized-type rule, then the tensor rule.
// NOTE(review): presumably relies on TypeConverter trying the most recently
// registered conversion first, so the identity rule acts as the fallback —
// confirm against the TypeConverter documentation.
54 explicit QuantizedTypeConverter() {
55 addConversion([](Type type) {
return type; });
56 addConversion(convertQuantizedType);
57 addConversion(convertTensorType);
// Pass class derived from the generated StripFuncQuantTypesBase. The visible
// code marks func ops as dynamically legal based on QuantizedTypeConverter and
// registers a function-signature type conversion pattern.
// NOTE(review): this listing is an excerpt; several original lines (helper
// header, `context` definition, closing braces, the conversion driver call)
// are not shown here.
66 class StripFuncQuantTypes
67 :
public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
// Fragment of a type check whose header is outside this excerpt: tensor types
// take a branch whose body is not visible; the visible return yields true when
// `type` is not a quant::QuantizedType.
72 if (
auto tensorType = dyn_cast<TensorType>(type))
74 return !isa<quant::QuantizedType>(type);
// Pass entry point: casts the current operation to ModuleOp, then sets up the
// type converter, conversion target, and pattern set.
78 void runOnOperation()
override {
80 auto moduleOp = cast<ModuleOp>(getOperation());
// `context` is defined on a line that is not part of this excerpt.
83 QuantizedTypeConverter typeConverter;
84 ConversionTarget target(*context);
85 RewritePatternSet patterns(context);
// A func::FuncOp is legal only when both its function type (signature) and
// its body are legal according to the type converter.
89 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
90 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
91 typeConverter.isLegal(&op.getBody());
// func::ReturnOp and func::CallOp are legal when the type converter reports
// the op itself as legal.
93 target.addDynamicallyLegalOp<func::ReturnOp>(
94 [&](func::ReturnOp op) {
return typeConverter.isLegal(op); });
95 target.addDynamicallyLegalOp<func::CallOp>(
96 [&](func::CallOp op) {
return typeConverter.isLegal(op); });
// Adds the function-op-interface type conversion pattern for func::FuncOp,
// driven by the same type converter.
99 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
100 patterns, typeConverter);
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
Include the generated interface declarations.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.