MLIR  21.0.0git
StripFuncQuantTypes.cpp
Go to the documentation of this file.
1 //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Strips quantized types from function headers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
25 
26 namespace mlir {
27 namespace quant {
28 
29 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
31 
32 namespace {
33 
34 class QuantizedTypeConverter : public TypeConverter {
35 
36  static Type convertQuantizedType(QuantizedType quantizedType) {
37  return quantizedType.getStorageType();
38  }
39 
40  static Type convertTensorType(TensorType tensorType) {
41  if (auto quantizedType =
42  dyn_cast<QuantizedType>(tensorType.getElementType()))
43  return tensorType.clone(convertQuantizedType(quantizedType));
44  return tensorType;
45  }
46 
47  static Value materializeConversion(OpBuilder &builder, Type type,
48  ValueRange inputs, Location loc) {
49  return builder.create<quant::StorageCastOp>(loc, type,
50  llvm::getSingleElement(inputs));
51  }
52 
53 public:
54  explicit QuantizedTypeConverter() {
55  addConversion([](Type type) { return type; });
56  addConversion(convertQuantizedType);
57  addConversion(convertTensorType);
58 
59  addSourceMaterialization(materializeConversion);
60  addTargetMaterialization(materializeConversion);
61  }
62 };
63 
64 // Conversion pass
65 class StripFuncQuantTypes
66  : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
67 
68 public:
69  void runOnOperation() override {
70 
71  auto moduleOp = cast<ModuleOp>(getOperation());
72  auto *context = &getContext();
73 
74  QuantizedTypeConverter typeConverter;
75  ConversionTarget target(*context);
76  RewritePatternSet patterns(context);
77 
78  // Mark func.func, func.return, and func.call illegal if they contain any
79  // quantized types.
80  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
81  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
82  typeConverter.isLegal(&op.getBody());
83  });
84  target.addDynamicallyLegalOp<func::ReturnOp>(
85  [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
86  target.addDynamicallyLegalOp<func::CallOp>(
87  [&](func::CallOp op) { return typeConverter.isLegal(op); });
88 
89  // Register conversion patterns
90  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
91  patterns, typeConverter);
94 
95  // Apply conversion
96  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
97  signalPassFailure();
98  }
99 };
100 
101 } // namespace
102 
103 } // namespace quant
104 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.