MLIR  21.0.0git
StripFuncQuantTypes.cpp
Go to the documentation of this file.
1 //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Strips quantized types from function headers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
27 namespace quant {
28 
29 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
31 
32 namespace {
33 
34 class QuantizedTypeConverter : public TypeConverter {
35 
36  static Type convertQuantizedType(QuantizedType quantizedType) {
37  return quantizedType.getStorageType();
38  }
39 
40  static Type convertTensorType(TensorType tensorType) {
41  if (auto quantizedType =
42  dyn_cast<QuantizedType>(tensorType.getElementType()))
43  return tensorType.clone(convertQuantizedType(quantizedType));
44  return tensorType;
45  }
46 
47  static Value materializeConversion(OpBuilder &builder, Type type,
48  ValueRange inputs, Location loc) {
49  assert(inputs.size() == 1);
50  return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
51  }
52 
53 public:
54  explicit QuantizedTypeConverter() {
55  addConversion([](Type type) { return type; });
56  addConversion(convertQuantizedType);
57  addConversion(convertTensorType);
58 
59  addSourceMaterialization(materializeConversion);
60  addTargetMaterialization(materializeConversion);
61  }
62 };
63 
64 // Conversion pass
65 class StripFuncQuantTypes
66  : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
67 
68  // Return whether a type is considered legal when occurring in the header of
69  // a function or as an operand to a 'return' op.
70  static bool isLegalType(Type type) {
71  if (auto tensorType = dyn_cast<TensorType>(type))
72  return isLegalType(tensorType.getElementType());
73  return !isa<quant::QuantizedType>(type);
74  }
75 
76 public:
77  void runOnOperation() override {
78 
79  auto moduleOp = cast<ModuleOp>(getOperation());
80  auto *context = &getContext();
81 
82  QuantizedTypeConverter typeConverter;
83  ConversionTarget target(*context);
84  RewritePatternSet patterns(context);
85 
86  // Mark func.func, func.return, and func.call illegal if they contain any
87  // quantized types.
88  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
89  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
90  typeConverter.isLegal(&op.getBody());
91  });
92  target.addDynamicallyLegalOp<func::ReturnOp>(
93  [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
94  target.addDynamicallyLegalOp<func::CallOp>(
95  [&](func::CallOp op) { return typeConverter.isLegal(op); });
96 
97  // Register conversion patterns
98  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
99  patterns, typeConverter);
102 
103  // Apply conversion
104  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
105  signalPassFailure();
106  }
107 };
108 
109 } // namespace
110 
111 } // namespace quant
112 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.