MLIR  20.0.0git
StripFuncQuantTypes.cpp
Go to the documentation of this file.
1 //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Strips quantized types from function headers.
10 //
11 //===----------------------------------------------------------------------===//
12 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
25 
26 namespace mlir {
27 namespace quant {
28 
29 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
30 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
31 
32 namespace {
33 
34 class QuantizedTypeConverter : public TypeConverter {
35 
36  static Type convertQuantizedType(QuantizedType quantizedType) {
37  return quantizedType.getStorageType();
38  }
39 
40  static Type convertTensorType(TensorType tensorType) {
41  if (auto quantizedType =
42  dyn_cast<QuantizedType>(tensorType.getElementType()))
43  return tensorType.clone(convertQuantizedType(quantizedType));
44  return tensorType;
45  }
46 
47  static Value materializeConversion(OpBuilder &builder, Type type,
48  ValueRange inputs, Location loc) {
49  assert(inputs.size() == 1);
50  return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
51  }
52 
53 public:
54  explicit QuantizedTypeConverter() {
55  addConversion([](Type type) { return type; });
56  addConversion(convertQuantizedType);
57  addConversion(convertTensorType);
58 
59  addArgumentMaterialization(materializeConversion);
60  addSourceMaterialization(materializeConversion);
61  addTargetMaterialization(materializeConversion);
62  }
63 };
64 
65 // Conversion pass
66 class StripFuncQuantTypes
67  : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
68 
69  // Return whether a type is considered legal when occurring in the header of
70  // a function or as an operand to a 'return' op.
71  static bool isLegalType(Type type) {
72  if (auto tensorType = dyn_cast<TensorType>(type))
73  return isLegalType(tensorType.getElementType());
74  return !isa<quant::QuantizedType>(type);
75  }
76 
77 public:
78  void runOnOperation() override {
79 
80  auto moduleOp = cast<ModuleOp>(getOperation());
81  auto *context = &getContext();
82 
83  QuantizedTypeConverter typeConverter;
84  ConversionTarget target(*context);
85  RewritePatternSet patterns(context);
86 
87  // Mark func.func, func.return, and func.call illegal if they contain any
88  // quantized types.
89  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
90  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
91  typeConverter.isLegal(&op.getBody());
92  });
93  target.addDynamicallyLegalOp<func::ReturnOp>(
94  [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
95  target.addDynamicallyLegalOp<func::CallOp>(
96  [&](func::CallOp op) { return typeConverter.isLegal(op); });
97 
98  // Register conversion patterns
99  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
100  patterns, typeConverter);
101  populateReturnOpTypeConversionPattern(patterns, typeConverter);
102  populateCallOpTypeConversionPattern(patterns, typeConverter);
103 
104  // Apply conversion
105  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
106  signalPassFailure();
107  }
108 };
109 
110 } // namespace
111 
112 } // namespace quant
113 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a cast could not be generated.
static bool isLegalType(Type type)
Returns true if the given type is considered as legal for SPIR-V conversion.
Include the generated interface declarations.
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the given type converter.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalized by the conversion framework.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.