MLIR  22.0.0git
StripFuncQuantTypes.cpp
Go to the documentation of this file.
1 //===- StripFuncQuantTypes.cpp - Strip quantized types --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Strips quantized types from function headers.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/IR/PatternMatch.h"
20 
21 namespace mlir {
22 namespace quant {
23 
24 #define GEN_PASS_DEF_STRIPFUNCQUANTTYPES
25 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
26 
27 namespace {
28 
29 class QuantizedTypeConverter : public TypeConverter {
30 
31  static Type convertQuantizedType(QuantizedType quantizedType) {
32  return quantizedType.getStorageType();
33  }
34 
35  static Type convertTensorType(TensorType tensorType) {
36  if (auto quantizedType =
37  dyn_cast<QuantizedType>(tensorType.getElementType()))
38  return tensorType.clone(convertQuantizedType(quantizedType));
39  return tensorType;
40  }
41 
42  static Value materializeConversion(OpBuilder &builder, Type type,
43  ValueRange inputs, Location loc) {
44  return quant::StorageCastOp::create(builder, loc, type,
45  llvm::getSingleElement(inputs));
46  }
47 
48 public:
49  explicit QuantizedTypeConverter() {
50  addConversion([](Type type) { return type; });
51  addConversion(convertQuantizedType);
52  addConversion(convertTensorType);
53 
54  addSourceMaterialization(materializeConversion);
55  addTargetMaterialization(materializeConversion);
56  }
57 };
58 
59 // Conversion pass
60 class StripFuncQuantTypes
61  : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
62 
63 public:
64  void runOnOperation() override {
65 
66  auto moduleOp = cast<ModuleOp>(getOperation());
67  auto *context = &getContext();
68 
69  QuantizedTypeConverter typeConverter;
70  ConversionTarget target(*context);
71  RewritePatternSet patterns(context);
72 
73  // Mark func.func, func.return, and func.call illegal if they contain any
74  // quantized types.
75  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
76  return typeConverter.isSignatureLegal(op.getFunctionType()) &&
77  typeConverter.isLegal(&op.getBody());
78  });
79  target.addDynamicallyLegalOp<func::ReturnOp>(
80  [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
81  target.addDynamicallyLegalOp<func::CallOp>(
82  [&](func::CallOp op) { return typeConverter.isLegal(op); });
83 
84  // Register conversion patterns
85  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
86  patterns, typeConverter);
89 
90  // Apply conversion
91  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
92  signalPassFailure();
93  }
94 };
95 
96 } // namespace
97 
98 } // namespace quant
99 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static Value materializeConversion(const DialectInlinerInterface *interface, SmallVectorImpl< Operation * > &castOps, OpBuilder &castBuilder, Value arg, Type type, Location conversionLoc)
Utility function used to generate a cast operation from the given interface, or return nullptr if a c...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateCallOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to convert the operand and result types of a CallOp with the ...
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.