#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
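
namespace {

// If 'inputType' is a tensor type, return its element type; if it is
// already a scalar type, return it as is.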
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}
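
// Return the shape of 'input' as a list of static dimension attributes and
// dynamic dimension values if it is a tensor, or an empty list if scalar.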
SmallVector<OpFoldResult>
getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}
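
// Clone 'referenceType' with element type 'elementType' if it is a tensor,
// or return 'elementType' directly for a scalar reference type.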
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}
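
// Turn 'scalar' into a constant usable alongside values of 'referenceType':
// a tensor splat of shape 'referenceShape' for tensors, or the scalar itself.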
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create a tensor splat of the requested shape.
  auto tensorConstant =
      builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
  return tensorConstant;
}
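
// Flatten an unranked tensor into a 1D ranked tensor of dynamic size.
// Returns the flattened tensor together with its original shape so that the
// shape can be restored afterwards.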
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get the unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
  Value inputSize = builder.create<shape::NumElementsOp>(
      loc, builder.getIndexType(), inputShape);

  // Turn the input size into a 1D shape tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, inputSize);

  // Reshape the input tensor into a 1D dynamically sized tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);
  return std::make_pair(flatInput, inputShape);
}
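
// Flatten an unranked tensor into a 3D ranked tensor of shape
// [?, axisSize, ?], keeping the given axis as the middle dimension.
// Returns the reshaped tensor and the original shape.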
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get the full input shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Compute the collapsed sizes to the left and right of the axis
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
  auto shapeRight =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);

  // Compute the flat shape [sizeLeft, axisSize, sizeRight]
  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape the input into a 3D tensor with the axis as middle dimension
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);

  return std::make_pair(flatInput, inputShape);
}
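
// Reshape a ranked tensor back into an unranked tensor, using a shape value
// previously captured by one of the flattening helpers above.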
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
                                           inputShape);
}
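
// Materialize the per-channel scales of 'quantizedType' as a 1D constant
// tensor of the expressed (floating-point) type.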
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs =
      llvm::map_to_vector(scales, [&](double scale) -> Attribute {
        return builder.getFloatAttr(expressedType, scale);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}
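
// Materialize the per-channel zero points of 'quantizedType' as a 1D
// constant tensor of the storage (integer) type.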
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs = llvm::map_to_vector(
      zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}
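
// Clamp 'input' to the storage bounds of 'quantizedType', if it narrows
// down the storage type range.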
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If the quantized type does not narrow down the storage type range,
  // there is nothing to clamp.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize the storage bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMin(), storageType);
  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMax(), storageType);
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
  } else {
    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
  }
  return input;
}
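
// Emit a float-to-integer conversion, signed or unsigned as requested.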
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::FPToSIOp>(loc, resultType, input);
  return builder.create<arith::FPToUIOp>(loc, resultType, input);
}
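
// Emit an integer-to-float conversion, signed or unsigned as requested.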
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::SIToFPOp>(loc, resultType, input);
  return builder.create<arith::UIToFPOp>(loc, resultType, input);
}
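
// Quantize a floating-point value: divide by the scale, add the zero point
// (skipped when it is statically zero), convert to the storage type, and
// clamp to the storage bounds if necessary.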
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Broadcast the scale to a tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType,
                                    inputShape);

  // Scale the input
  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);

  // Skip the zero-point addition if the zero point is statically zero
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Broadcast the zero point and convert it to the expressed type
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());
    storedValueFloat =
        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
  }

  // Convert the stored value to the storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp the stored value if the storage type is bounded
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}
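
// Dequantize a stored integer value: convert to the expressed float type,
// subtract the zero point (skipped when it is statically zero), and
// multiply by the scale.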
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Broadcast the scale to a tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType,
                                    inputShape);

  // Convert the stored value to the expressed float type
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip the zero-point subtraction if the zero point is statically zero
  if (!matchPattern(zeroPoint, m_Zero())) {
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());
    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
  }

  // Multiply by the scale
  result = builder.create<arith::MulFOp>(loc, result, scale);
  return result;
}
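
// Dispatch to quantizeValue() or dequantizeValue() depending on whether a
// 'quant.qcast' or 'quant.dcast' op is being lowered.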
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create the scale and zero-point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}
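
// Lower a per-layer quantization op, flattening an unranked tensor input
// into a 1D ranked tensor first and restoring its shape afterwards.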
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
  return result;
}
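
// Lower a per-channel quantization op on a ranked tensor with a
// 'linalg.generic' op whose indexing maps broadcast the 1D scale and
// zero-point tensors along the channel axis.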
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // The result element type is the storage type when quantizing a float
  // input and the expressed type when dequantizing an integer input
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(
      inputRank, utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank),
      channelAxisAffineMap,
      channelAxisAffineMap,
      builder.getMultiDimIdentityMap(inputRank)
  };
  auto result = builder.create<linalg::GenericOp>(
      loc,
      init.getType(), // resultType
      ValueRange{input, scales, zeroPoints}, // inputs
      ValueRange{init}, // outputs
      indexingMaps,
      iteratorTypes,
      [&](OpBuilder &builder, Location loc, ValueRange args) {
        assert(args.size() == 4);
        auto input = args[0];
        auto scale = args[1];
        auto zeroPoint = args[2];

        auto result = convertRanked(builder, loc, op, input, {}, scale,
                                    zeroPoint, quantizedType);

        builder.create<linalg::YieldOp>(loc, result);
      })
      .getResult(0);

  return result;
}
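
// Lower a per-channel quantization op. An unranked tensor input is first
// reshaped into a 3D tensor with the channel axis as middle dimension.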
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
  return result;
}
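
// Lower a quantize or dequantize op according to whether its quantized type
// is per-layer (UniformQuantizedType) or per-axis
// (UniformQuantizedPerAxisType).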
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType =
          dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  llvm_unreachable("unexpected quantized type");
}
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert the quantized input to its storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};
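
// Lowers 'quant.qcast' to primitive ops, casting the computed storage
// values into the quantized result type with a trailing 'quant.scast'.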
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast the stored value into the quantized result type
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};
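
// Pass implementation: converts 'quant.qcast' and 'quant.dcast' ops into
// arith, linalg, shape, and tensor ops, keeping 'quant.scast' as the only
// legal quant op. For example, a per-layer 'quant.qcast' on a ranked tensor
// lowers roughly to arith.divf (scale), arith.addf (zero point, if nonzero),
// arith.fptosi, and clamping, all wrapped in a final 'quant.scast'.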
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}