28#define GEN_PASS_DEF_LOWERQUANTOPS
29#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
36 if (
auto tensorType = dyn_cast<TensorType>(inputType))
37 return tensorType.getElementType();
46 if (isa<TensorType>(input.
getType()))
54Type getScalarOrTensorType(
Type elementType,
Type referenceType) {
55 if (
auto tensorType = dyn_cast<TensorType>(referenceType))
56 return tensorType.clone(elementType);
67 auto tensorType = dyn_cast<TensorType>(referenceType);
69 assert(referenceShape.empty());
75 tensor::SplatOp::create(builder, loc, scalar, referenceShape);
76 return tensorConstant;
97 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
98 Value inputSize = shape::NumElementsOp::create(
103 auto flatInputShape =
104 tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
107 auto inputType = cast<UnrankedTensorType>(input.
getType());
108 auto elementType = inputType.getElementType();
110 RankedTensorType::get({ShapedType::kDynamic}, elementType);
111 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
113 return std::make_pair(flatInput, inputShape);
138std::pair<Value, Value>
145 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
151 shape::SplitAtOp::create(builder, loc,
TypeRange{shapeType, shapeType},
152 inputShape, axisValue)
155 shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
157 shape::SplitAtOp::create(builder, loc,
TypeRange{shapeType, shapeType},
158 inputShape, axisNextValue)
161 shape::NumElementsOp::create(builder, loc, indexType, shapeRight);
166 auto flatInputShape = tensor::FromElementsOp::create(
167 builder, loc, flatShapeType,
168 ValueRange{sizeLeft, axisSizeValue, sizeRight});
171 auto inputType = cast<UnrankedTensorType>(input.
getType());
172 auto elementType = inputType.getElementType();
173 auto flatInputType = RankedTensorType::get(
174 {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
175 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
178 return std::make_pair(flatInput, inputShape);
189Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
191 auto inputType = cast<RankedTensorType>(input.getType());
192 auto elementType = inputType.getElementType();
193 auto unrankedType = UnrankedTensorType::get(elementType);
194 return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
207Value materializePerChannelScales(OpBuilder &builder, Location loc,
209 auto scales = quantizedType.getScales();
210 auto expressedType = quantizedType.getExpressedType();
211 auto scaleAttrs = llvm::map_to_vector(scales, [&](
double scale) -> Attribute {
212 return builder.getFloatAttr(expressedType, scale);
215 RankedTensorType::get({(int64_t)scales.size()}, expressedType);
217 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
229Value materializePerChannelZeroPoints(
230 OpBuilder &builder, Location loc,
232 auto zeroPoints = quantizedType.getZeroPoints();
233 auto storageType = quantizedType.getStorageType();
234 auto zeroPointAttrs =
235 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
236 return builder.getIntegerAttr(storageType, zeroPoint);
239 RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
241 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
253Value materializeSubChannelScales(
254 OpBuilder &builder, Location loc,
256 auto scales = quantizedType.getScales();
257 auto expressedType = quantizedType.getExpressedType();
258 auto scaleAttrs = llvm::map_to_vector(
259 scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
260 return builder.getFloatAttr(expressedType, scale);
263 RankedTensorType::get(scales.getType().getShape(), expressedType);
265 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
277Value materializeSubChannelZeroPoints(
278 OpBuilder &builder, Location loc,
280 auto zeroPoints = quantizedType.getZeroPoints();
281 auto storageType = quantizedType.getStorageType();
282 auto zeroPointAttrs = llvm::map_to_vector(
283 zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
284 return builder.getIntegerAttr(storageType, zeroPoint);
287 RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
289 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
305Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
306 ArrayRef<OpFoldResult> inputShape,
310 if (!quantizedType.hasStorageTypeBounds())
314 auto inputType = input.getType();
315 auto storageType = quantizedType.getStorageType();
317 builder, loc, storageType, quantizedType.getStorageTypeMin());
319 builder, loc, storageType, quantizedType.getStorageTypeMax());
320 auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
321 inputType, inputShape);
322 auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
323 inputType, inputShape);
326 if (quantizedType.isSigned()) {
327 input = arith::MaxSIOp::create(builder, loc, input, storageMin);
328 input = arith::MinSIOp::create(builder, loc, input, storageMax);
330 input = arith::MaxUIOp::create(builder, loc, input, storageMin);
331 input = arith::MinUIOp::create(builder, loc, input, storageMax);
337Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
338 Type resultType,
bool isSigned) {
340 return arith::FPToSIOp::create(builder, loc, resultType, input);
341 return arith::FPToUIOp::create(builder, loc, resultType, input);
345Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
346 Type resultType,
bool isSigned) {
348 return arith::SIToFPOp::create(builder, loc, resultType, input);
349 return arith::UIToFPOp::create(builder, loc, resultType, input);
356Value quantizeValue(OpBuilder &builder, Location loc, Value input,
357 ArrayRef<OpFoldResult> inputShape, Value scale,
360 auto inputType = input.getType();
361 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
364 auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);
367 Value storedValueFloat = scaledValue;
370 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
374 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
375 quantizedType.isSigned());
379 arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
383 auto storageScalarOrTensorType =
384 getScalarOrTensorType(quantizedType.getStorageType(), inputType);
385 auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
386 storageScalarOrTensorType,
387 quantizedType.isSigned());
390 auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
391 inputShape, quantizedType);
392 return storedValueClamped;
398Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
399 ArrayRef<OpFoldResult> inputShape, Value scale,
402 auto inputType = input.getType();
403 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
406 auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
407 quantizedType.isSigned());
412 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
416 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
417 quantizedType.isSigned());
420 result = arith::SubFOp::create(builder, loc,
result, zeroPoint);
424 result = arith::MulFOp::create(builder, loc,
result, scale);
448Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
449 Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
451 if (isa<QuantizeCastOp>(op))
452 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
454 if (isa<DequantizeCastOp>(op))
455 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
457 llvm_unreachable(
"unexpected quant op");
472Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
475 auto expressedType = quantizedType.getExpressedType();
476 auto storageType = quantizedType.getStorageType();
478 builder.getFloatAttr(expressedType, quantizedType.getScale());
480 arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
482 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
484 arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
486 auto inputShape = getScalarOrTensorShape(builder, loc, input);
487 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
502Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
505 bool isUnranked = isa<UnrankedTensorType>(input.getType());
508 std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
511 auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
515 result = restoreUnrankedTensorShape(builder, loc,
result, inputShape);
532Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
535 int64_t channelAxis) {
536 auto *context = builder.getContext();
538 auto inputType = cast<RankedTensorType>(input.getType());
539 auto inputRank = inputType.getRank();
541 auto scales = materializePerChannelScales(builder, loc, quantizedType);
543 materializePerChannelZeroPoints(builder, loc, quantizedType);
545 auto elementType = isa<FloatType>(inputType.getElementType())
546 ? quantizedType.getStorageType()
547 : quantizedType.getExpressedType();
549 Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
551 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
552 utils::IteratorType::parallel);
554 inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
555 SmallVector<AffineMap> indexingMaps{
556 builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
557 channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
558 auto result = linalg::GenericOp::create(
563 indexingMaps, iteratorTypes,
564 [&](OpBuilder &builder, Location loc,
ValueRange args) {
565 assert(args.size() == 4);
566 auto input = args[0];
567 auto scale = args[1];
568 auto zeroPoint = args[2];
571 convertRanked(builder, loc, op, input, {}, scale,
572 zeroPoint, quantizedType);
574 linalg::YieldOp::create(builder, loc,
result);
592Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
596 bool isUnranked = isa<UnrankedTensorType>(input.getType());
597 int64_t channelAxis = quantizedType.getQuantizedDimension();
598 int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
601 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
602 builder, loc, input, channelAxis, channelAxisSize);
607 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
612 result = restoreUnrankedTensorShape(builder, loc,
result, inputShape);
628Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
631 auto *context = builder.getContext();
633 auto inputType = cast<RankedTensorType>(input.getType());
634 auto inputRank = inputType.getRank();
636 auto scales = materializeSubChannelScales(builder, loc, quantizedType);
638 materializeSubChannelZeroPoints(builder, loc, quantizedType);
640 auto elementType = isa<FloatType>(inputType.getElementType())
641 ? quantizedType.getStorageType()
642 : quantizedType.getExpressedType();
644 Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
646 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
647 utils::IteratorType::parallel);
648 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
649 quantizedType.getBlockSizeInfo();
650 SmallVector<AffineExpr> affineExprs(inputRank,
651 builder.getAffineConstantExpr(0));
652 for (
auto [quantizedDimension, blockSize] : blockSizeInfo) {
653 affineExprs[quantizedDimension] =
654 builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
656 auto affineMap =
AffineMap::get(inputRank, 0, affineExprs, context);
657 SmallVector<AffineMap> indexingMaps{
658 builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
659 builder.getMultiDimIdentityMap(inputRank)};
660 auto result = linalg::GenericOp::create(
665 indexingMaps, iteratorTypes,
666 [&](OpBuilder &builder, Location loc,
ValueRange args) {
667 assert(args.size() == 4);
668 auto input = args[0];
669 auto scale = args[1];
670 auto zeroPoint = args[2];
673 convertRanked(builder, loc, op, input, {}, scale,
674 zeroPoint, quantizedType);
676 linalg::YieldOp::create(builder, loc,
result);
698 if (
auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
699 return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
701 if (
auto uniformQuantizedPerAxisType =
702 dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
703 return convertPerChannel(builder, loc, op, input,
704 uniformQuantizedPerAxisType);
706 if (
auto uniformQuantizedSubChannelType =
707 dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
708 return convertSubChannel(builder, loc, op, input,
709 uniformQuantizedSubChannelType);
711 llvm_unreachable(
"unexpected quantized type");
715struct DequantizeCastOpConversion
716 :
public OpConversionPattern<quant::DequantizeCastOp> {
717 using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
720 matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter)
const override {
722 auto loc = op.getLoc();
723 auto input = op.getInput();
725 cast<QuantizedType>(getScalarType(op.getInput().getType()));
728 auto storageScalarOrTensorType =
729 getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
730 input = quant::StorageCastOp::create(rewriter, loc,
731 storageScalarOrTensorType, input);
733 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
735 rewriter.replaceOp(op,
result);
741struct QuantizeCastOpConversion
742 :
public OpConversionPattern<quant::QuantizeCastOp> {
743 using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
746 matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter)
const override {
748 auto loc = op.getLoc();
749 auto input = op.getInput();
750 auto quantizedType = getScalarType(op.getResult().getType());
753 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
756 rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
762struct LowerQuantOps :
public impl::LowerQuantOpsBase<LowerQuantOps> {
763 void runOnOperation()
override {
768 target.addLegalOp<quant::StorageCastOp>();
769 target.addIllegalDialect<quant::QuantDialect>();
770 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
771 shape::ShapeDialect, tensor::TensorDialect>();
773 if (
failed(applyPartialConversion(getOperation(),
target,
774 std::move(patterns))))
782 patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Base class for all quantized types known to this dialect.
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns