MLIR  20.0.0git
StripFuncQuantTypes.cpp
Go to the documentation of this file.
1 //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Strips quantized types from function headers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
25 
26 namespace mlir {
27 namespace quant {
28 
29 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
31 
32 namespace {
33 
34 class QuantizedTypeConverter : public TypeConverter {
35 
36  static Type convertQuantizedType(QuantizedType quantizedType) {
37  return quantizedType.getStorageType();
38  }
39 
40  static Type convertTensorType(TensorType tensorType) {
41  if (auto quantizedType =
42  dyn_cast<QuantizedType>(tensorType.getElementType()))
43  return tensorType.clone(convertQuantizedType(quantizedType));
44  return tensorType;
45  }
46 
47  static Value materializeConversion(OpBuilder &builder, Type type,
48  ValueRange inputs, Location loc) {
49  assert(inputs.size() == 1);
50  return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
51  }
52 
53 public:
54  explicit QuantizedTypeConverter() {
55  addConversion([](Type type) { return type; });
56  addConversion(convertQuantizedType);
57  addConversion(convertTensorType);
58 
59  addSourceMaterialization(materializeConversion);
60  addTargetMaterialization(materializeConversion);
61  }
62 };
63 
64 // Conversion pass
65 class StripFuncQuantTypes
66  : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
67 
68  // Return whether a type is considered legal when occurring in the header of
69  // a function or as an operand to a 'return' op.
70  static bool isLegalType(Type type) {
71  if (auto tensorType = dyn_cast<TensorType>(type))
72  return isLegalType(tensorType.getElementType());
73  return !isa<quant::QuantizedType>(type);
74  }
75 
76 public:
77  void runOnOperation() override {
78 
79  auto moduleOp = cast<ModuleOp>(getOperation());
80  auto *context = &getContext();
81 
82  QuantizedTypeConverter typeConverter;
83  ConversionTarget target(*context);
84  RewritePatternSet patterns(context);
85 
86  // Mark func.func, func.return, and func.call illegal if they contain any
87  // quantized types.
88  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
89  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
90  typeConverter.isLegal(&op.getBody());
91  });
92  target.addDynamicallyLegalOp<func::ReturnOp>(
93  [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
94  target.addDynamicallyLegalOp<func::CallOp>(
95  [&](func::CallOp op) { return typeConverter.isLegal(op); });
96 
97  // Register conversion patterns
98  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
102 
103  // Apply conversion
104  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
105  signalPassFailure();
106  }
107 };
108 
109 } // namespace
110 
111 } // namespace quant
112 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
bool isSignatureLegal(FunctionType ty) const
Return true if the inputs and outputs of the given function type are legal.
Include the generated interface declarations.
TypeConverter & typeConverter
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.