MLIR  14.0.0git
QuantOps.cpp
Go to the documentation of this file.
1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "TypeDetail.h"
11 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/MLIRContext.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/ADT/Twine.h"
19 #include "llvm/Support/MathExtras.h"
20 #include <numeric>
21 
22 using namespace mlir;
23 using namespace mlir::quant;
24 using namespace mlir::quant::detail;
25 
26 #include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
27 
28 void QuantizationDialect::initialize() {
31  addOperations<
32 #define GET_OP_LIST
33 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
34  >();
35 }
36 
37 OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
38  // Matches x -> [scast -> scast] -> y, replacing the second scast with the
39  // value of x if the casts invert each other.
40  auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
41  if (!srcScastOp || srcScastOp.arg().getType() != getType())
42  return OpFoldResult();
43  return srcScastOp.arg();
44 }
45 
46 /// The quantization specification should match the expressed type.
47 static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
48  if (auto typeAttr = quantSpec.dyn_cast<TypeAttr>()) {
49  Type spec = typeAttr.getValue();
50  if (spec.isa<TensorType, VectorType>())
51  return false;
52 
53  // The spec should be either a quantized type which is compatible to the
54  // expressed type, or a primitive type which is as same as the
55  // (element type of) the expressed type.
56  if (auto quantizedType = spec.dyn_cast<QuantizedType>())
57  return quantizedType.isCompatibleExpressedType(expressed);
58 
59  if (auto tensorType = expressed.dyn_cast<TensorType>())
60  return spec == tensorType.getElementType();
61 
62  if (auto vectorType = expressed.dyn_cast<VectorType>())
63  return spec == vectorType.getElementType();
64  }
65  return false;
66 }
67 
68 static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
69  // There are specifications for both inputs and outputs.
70  if (op.getNumOperands() != op.input_specs().size() ||
71  op.getNumResults() != op.output_specs().size())
72  return op.emitOpError(
73  "has unmatched operands/results number and spec attributes number");
74 
75  // Verify that quantization specifications are valid.
76  for (auto input : llvm::zip(op.getOperandTypes(), op.input_specs())) {
77  Type inputType = std::get<0>(input);
78  Attribute inputSpec = std::get<1>(input);
79  if (!isValidQuantizationSpec(inputSpec, inputType)) {
80  return op.emitOpError() << "has incompatible specification " << inputSpec
81  << " and input type " << inputType;
82  }
83  }
84 
85  for (auto result : llvm::zip(op.getResultTypes(), op.output_specs())) {
86  Type outputType = std::get<0>(result);
87  Attribute outputSpec = std::get<1>(result);
88  if (!isValidQuantizationSpec(outputSpec, outputType)) {
89  return op.emitOpError() << "has incompatible specification " << outputSpec
90  << " and output type " << outputType;
91  }
92  }
93  return success();
94 }
95 
96 #define GET_OP_CLASSES
97 #include "mlir/Dialect/Quant/QuantOps.cpp.inc"
Include the generated interface declarations.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:214
Represents a family of uniform, quantized types.
Definition: QuantTypes.h:256
A quantized type that infers its range from given min/max values.
Definition: QuantTypes.h:383
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed)
The quantization specification should match the expressed type.
Definition: QuantOps.cpp:47
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
A quantized type that maps storage to/from expressed types in an unspecified way. ...
Definition: QuantTypes.h:197
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Represents per-axis (also known as per-channel) quantization.
Definition: QuantTypes.h:314
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:73
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
U dyn_cast() const
Definition: Attributes.h:117
static LogicalResult verifyRegionOp(QuantizeRegionOp op)
Definition: QuantOps.cpp:68
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
Base class for all quantized types known to this dialect.
Definition: QuantTypes.h:52
bool isa() const
Definition: Types.h:234