28 #define GEN_PASS_DEF_LOWERQUANTOPS
29 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
// Returns the element type of `inputType` when it is a TensorType.
// NOTE(review): the non-tensor fallback (presumably returning `inputType`
// itself for scalar types) is not visible in this listing -- confirm.
35 Type getScalarType(Type inputType) {
36 if (
auto tensorType = dyn_cast<TensorType>(inputType))
37 return tensorType.getElementType();
// Returns the shape of `input` as a list of OpFoldResults; the visible
// branch handles TensorType-typed inputs.
// NOTE(review): both branch bodies are missing from this listing. Presumably
// the tensor case returns its mixed static/dynamic sizes and a scalar yields
// an empty list -- confirm.
44 SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45 Location loc, Value input) {
46 if (isa<TensorType>(input.getType()))
// Returns a type carrying `elementType`: if `referenceType` is a TensorType,
// it is cloned with `elementType` as its new element type.
// NOTE(review): the non-tensor fallback is not visible in this listing
// (presumably it returns `elementType` unchanged) -- confirm.
54 Type getScalarOrTensorType(Type elementType, Type referenceType) {
55 if (
auto tensorType = dyn_cast<TensorType>(referenceType))
56 return tensorType.clone(elementType);
// Produces `scalar` shaped like the reference value.
// - For a tensor reference, `scalar` is broadcast with a tensor::SplatOp
//   using `referenceShape`, and the splat is returned.
// - On the other path (presumably the non-tensor one), `referenceShape` is
//   asserted to be empty.
// NOTE(review): several lines are missing from this listing, including the
// `referenceType` parameter (original line 64) that callers pass as the
// fourth argument, and the branch structure itself.
63 Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
65 ArrayRef<OpFoldResult> referenceShape) {
67 auto tensorType = dyn_cast<TensorType>(referenceType);
69 assert(referenceShape.empty());
75 tensor::SplatOp::create(builder, loc, scalar, referenceShape);
76 return tensorConstant;
// Flattens an unranked tensor `input` into a tensor with a single dimension.
// Steps:
// 1. Compute the input's runtime shape (shape::ShapeOfOp).
// 2. Compute its total element count (shape::NumElementsOp).
// 3. Pack that count into a one-element shape tensor (tensor::FromElementsOp).
// 4. Reshape `input` to that shape with tensor::ReshapeOp.
// Returns {flattened tensor, original shape}. The shape value is what
// restoreUnrankedTensorShape later uses to undo the flattening.
// NOTE(review): the definitions of shapeType, flatShapeType and
// flatInputType are not visible in this listing.
92 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
95 auto *context = builder.getContext();
97 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
98 Value inputSize = shape::NumElementsOp::create(
99 builder, loc, builder.getIndexType(), inputShape);
103 auto flatInputShape =
104 tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
107 auto inputType = cast<UnrankedTensorType>(input.getType());
108 auto elementType = inputType.getElementType();
111 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
113 return std::make_pair(flatInput, inputShape);
// Flattens an unranked tensor `input` into a rank-3 tensor of shape
// [?, axisSize, ?] while keeping dimension `axis` intact.
// - The runtime shape is split with shape::SplitAtOp at `axis`, and again at
//   `axisNextValue` (presumably axis + 1).
// - The element counts of the selected halves (shape::NumElementsOp) become
//   the dynamic outer sizes sizeLeft and sizeRight.
// - The middle size is the static `axisSize`.
// Returns {flattened tensor, original shape} so the shape can be restored
// after conversion.
// NOTE(review): the lines that pick which SplitAtOp result is used and that
// define axisValue, axisNextValue and axisSizeValue are not visible in this
// listing -- confirm.
138 std::pair<Value, Value>
139 flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
140 int64_t axis, int64_t axisSize) {
142 auto *context = builder.getContext();
143 auto indexType = builder.getIndexType();
145 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
151 shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
152 inputShape, axisValue)
155 shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
157 shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
158 inputShape, axisNextValue)
161 shape::NumElementsOp::create(builder, loc, indexType, shapeRight);
166 auto flatInputShape = tensor::FromElementsOp::create(
167 builder, loc, flatShapeType,
168 ValueRange{sizeLeft, axisSizeValue, sizeRight});
171 auto inputType = cast<UnrankedTensorType>(input.getType());
172 auto elementType = inputType.getElementType();
174 {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
175 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
178 return std::make_pair(flatInput, inputShape);
// Reshapes ranked tensor `input` back to an unranked tensor with the same
// element type, using tensor::ReshapeOp.
// NOTE(review): two things are not visible here -- confirm:
// - the shape operand (presumably the `inputShape` returned by one of the
//   flatten helpers, declared on missing original line 190);
// - the construction of `unrankedType`.
189 Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
191 auto inputType = cast<RankedTensorType>(input.getType());
192 auto elementType = inputType.getElementType();
194 return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
// Materializes the per-channel scales of `quantizedType` as an
// arith::ConstantOp. Each scale becomes a float attribute of the expressed
// type.
// NOTE(review): the construction of `tensorType` and `scalesAttr` is not
// visible. Presumably it is a 1-D tensor of the expressed type with one entry
// per scale -- confirm.
207 Value materializePerChannelScales(OpBuilder &builder, Location loc,
208 UniformQuantizedPerAxisType quantizedType) {
209 auto scales = quantizedType.getScales();
210 auto expressedType = quantizedType.getExpressedType();
211 auto scaleAttrs = llvm::map_to_vector(scales, [&](
double scale) -> Attribute {
212 return builder.getFloatAttr(expressedType, scale);
217 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
// Materializes the per-channel zero points of `quantizedType` as an
// arith::ConstantOp. Each zero point becomes an integer attribute of the
// storage type.
// NOTE(review): the construction of `tensorType` and `zeroPointsAttr` is not
// visible. Presumably it is a 1-D tensor of the storage type with one entry
// per zero point -- confirm.
229 Value materializePerChannelZeroPoints(
230 OpBuilder &builder, Location loc,
231 UniformQuantizedPerAxisType quantizedType) {
232 auto zeroPoints = quantizedType.getZeroPoints();
233 auto storageType = quantizedType.getStorageType();
234 auto zeroPointAttrs =
235 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
236 return builder.getIntegerAttr(storageType, zeroPoint);
241 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
// Materializes the sub-channel (block-wise) scales of `quantizedType` as an
// arith::ConstantOp. The scales attribute is read element-by-element as
// APFloat, and each value becomes a float attribute of the expressed type.
// NOTE(review): the construction of `tensorType` and `scalesAttr` is not
// visible in this listing. Presumably it keeps the shape of the original
// scales attribute -- confirm.
253 Value materializeSubChannelScales(
254 OpBuilder &builder, Location loc,
255 UniformQuantizedSubChannelType quantizedType) {
256 auto scales = quantizedType.getScales();
257 auto expressedType = quantizedType.getExpressedType();
258 auto scaleAttrs = llvm::map_to_vector(
259 scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
260 return builder.getFloatAttr(expressedType, scale);
265 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
// Materializes the sub-channel (block-wise) zero points of `quantizedType` as
// an arith::ConstantOp. The zero-points attribute is read element-by-element
// as APInt, and each value becomes an integer attribute of the storage type.
// NOTE(review): the construction of `tensorType` and `zeroPointsAttr` is not
// visible in this listing. Presumably it keeps the shape of the original
// zero-points attribute -- confirm.
277 Value materializeSubChannelZeroPoints(
278 OpBuilder &builder, Location loc,
279 UniformQuantizedSubChannelType quantizedType) {
280 auto zeroPoints = quantizedType.getZeroPoints();
281 auto storageType = quantizedType.getStorageType();
282 auto zeroPointAttrs = llvm::map_to_vector(
283 zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
284 return builder.getIntegerAttr(storageType, zeroPoint);
289 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
// Clamps integer `input` to [getStorageTypeMin(), getStorageTypeMax()] of
// `quantizedType`. This only happens when the type declares storage bounds
// (hasStorageTypeBounds). The bounds are materialized as storage-typed
// constants and broadcast to input's type via getScalarOrTensorConstant.
// NOTE(review): the body of the no-bounds early exit, the constant-creation
// calls on lines 316/318, and the final return are not visible in this
// listing.
305 Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
306 ArrayRef<OpFoldResult> inputShape,
307 QuantizedType quantizedType) {
310 if (!quantizedType.hasStorageTypeBounds())
314 auto inputType = input.getType();
315 auto storageType = quantizedType.getStorageType();
317 builder, loc, storageType, quantizedType.getStorageTypeMin());
319 builder, loc, storageType, quantizedType.getStorageTypeMax());
320 auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
321 inputType, inputShape);
322 auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
323 inputType, inputShape);
// Integer min/max ops are sign-specific, so pick the signed or unsigned
// forms to match the storage type's signedness.
326 if (quantizedType.isSigned()) {
327 input = arith::MaxSIOp::create(builder, loc, input, storageMin);
328 input = arith::MinSIOp::create(builder, loc, input, storageMax);
330 input = arith::MaxUIOp::create(builder, loc, input, storageMin);
331 input = arith::MinUIOp::create(builder, loc, input, storageMax);
// Converts floating-point `input` to the integer `resultType`. It uses
// arith::FPToSIOp when `isSigned` and arith::FPToUIOp otherwise. The
// selecting `if` on original line 339 is not visible in this listing.
// Both casts round toward zero; they do not round to nearest.
337 Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
338 Type resultType,
bool isSigned) {
340 return arith::FPToSIOp::create(builder, loc, resultType, input);
341 return arith::FPToUIOp::create(builder, loc, resultType, input);
// Converts integer `input` to the floating-point `resultType`. The integer is
// interpreted as signed (arith::SIToFPOp) when `isSigned` and as unsigned
// (arith::UIToFPOp) otherwise. The selecting `if` on original line 347 is not
// visible in this listing.
345 Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
346 Type resultType,
bool isSigned) {
348 return arith::SIToFPOp::create(builder, loc, resultType, input);
349 return arith::UIToFPOp::create(builder, loc, resultType, input);
// Quantizes floating-point `input` with `scale` and `zeroPoint`:
//   stored = clamp(toStorageInt(input / scale + float(zeroPoint)))
// Steps:
// 1. Broadcast `scale` to input's type (getScalarOrTensorConstant).
// 2. Divide `input` by the broadcast scale.
// 3. Broadcast the zero point, convert it to scale's float type and add it.
// 4. Convert the sum to the storage integer type.
// 5. Clamp to the storage bounds (clampScalarOrTensor).
// NOTE(review): the condition guarding the zero-point addition (original
// lines 368-369 / 376-378) is not visible in this listing. Presumably it
// skips a zero-valued zero point -- confirm.
356 Value quantizeValue(OpBuilder &builder, Location loc, Value input,
357 ArrayRef<OpFoldResult> inputShape, Value scale,
358 Value zeroPoint, QuantizedType quantizedType) {
360 auto inputType = input.getType();
361 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
364 auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);
367 Value storedValueFloat = scaledValue;
370 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
374 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
375 quantizedType.isSigned());
379 arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
// NOTE(review): two concerns with the conversion below -- confirm the
// intended quantization semantics:
// - No rounding step is visible before convertFloatToInteger, whose
//   fptosi/fptoui casts round toward zero.
// - Clamping happens only after the cast. A float value outside the range of
//   the storage *integer* type is therefore not saturated by the clamp (it is
//   presumably poison after lowering to LLVM fptosi/fptoui).
383 auto storageScalarOrTensorType =
384 getScalarOrTensorType(quantizedType.getStorageType(), inputType);
385 auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
386 storageScalarOrTensorType,
387 quantizedType.isSigned());
390 auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
391 inputShape, quantizedType);
392 return storedValueClamped;
// Dequantizes integer `input`:
//   result = (float(input) - float(zeroPoint)) * scale
// Steps:
// 1. Broadcast `scale` to input's type.
// 2. Convert `input` to scale's float type (signed or unsigned per
//    `quantizedType`).
// 3. Subtract the broadcast zero point, converted to float the same way.
// 4. Multiply by the scale.
// NOTE(review): the condition guarding the zero-point subtraction and the
// final return are not visible in this listing.
398 Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
399 ArrayRef<OpFoldResult> inputShape, Value scale,
400 Value zeroPoint, QuantizedType quantizedType) {
402 auto inputType = input.getType();
403 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
406 auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
407 quantizedType.isSigned());
412 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
416 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
417 quantizedType.isSigned());
420 result = arith::SubFOp::create(builder, loc, result, zeroPoint);
424 result = arith::MulFOp::create(builder, loc, result, scale);
// Dispatches on the quant op being lowered:
// - QuantizeCastOp goes to quantizeValue;
// - DequantizeCastOp goes to dequantizeValue;
// - any other op is a programming error (llvm_unreachable).
448 Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
449 Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
450 Value zeroPoint, QuantizedType quantizedType) {
451 if (isa<QuantizeCastOp>(op))
452 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
454 if (isa<DequantizeCastOp>(op))
455 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
457 llvm_unreachable(
"unexpected quant op");
// Lowers a per-layer quant op (a single scale and zero point) on a scalar or
// ranked-tensor `input`.
// - The scale is materialized as a scalar arith::ConstantOp of the expressed
//   type, and the zero point as one of the storage type.
// - quantizeValue / dequantizeValue broadcast both to input's shape, which is
//   obtained from getScalarOrTensorShape.
472 Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
473 Value input, UniformQuantizedType quantizedType) {
475 auto expressedType = quantizedType.getExpressedType();
476 auto storageType = quantizedType.getStorageType();
478 builder.getFloatAttr(expressedType, quantizedType.getScale());
480 arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
482 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
484 arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
486 auto inputShape = getScalarOrTensorShape(builder, loc, input);
487 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
// Lowers a per-layer quant op on any supported input.
// - Unranked tensors are first flattened (flattenUnrankedTensor), converted
//   by convertPerLayerRanked, then reshaped back to their original shape
//   (restoreUnrankedTensorShape).
// - Other inputs go straight to convertPerLayerRanked.
// NOTE(review): the `if (isUnranked)` guards, the `inputShape` declaration
// and the return are not visible in this listing.
502 Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
503 Value input, UniformQuantizedType quantizedType) {
505 bool isUnranked = isa<UnrankedTensorType>(input.getType());
508 std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
511 auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
515 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
// Lowers a per-channel quant op on ranked tensor `input` into a
// linalg::GenericOp with all-parallel iterators.
// Indexing maps:
// - input and output use identity maps;
// - the scales and zero-point constants use a map whose single result is
//   dimension `channelAxis`, so each element reads its channel's parameters.
// The body applies convertRanked element-wise, with an empty shape because
// the values are scalars.
// Output element type: the storage type when the input element type is a
// float (i.e. quantizing), and the expressed type otherwise.
// NOTE(review): the `input` parameter line, the `initShape` computation, and
// the generic op's result-type / outs arguments are not visible in this
// listing.
532 Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
534 UniformQuantizedPerAxisType quantizedType,
535 int64_t channelAxis) {
536 auto *context = builder.getContext();
538 auto inputType = cast<RankedTensorType>(input.getType());
539 auto inputRank = inputType.getRank();
541 auto scales = materializePerChannelScales(builder, loc, quantizedType);
543 materializePerChannelZeroPoints(builder, loc, quantizedType);
545 auto elementType = isa<FloatType>(inputType.getElementType())
546 ? quantizedType.getStorageType()
547 : quantizedType.getExpressedType();
549 Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
551 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
552 utils::IteratorType::parallel);
554 inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
555 SmallVector<AffineMap> indexingMaps{
556 builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
557 channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
558 auto result = linalg::GenericOp::create(
561 ValueRange{input, scales, zeroPoints},
563 indexingMaps, iteratorTypes,
564 [&](OpBuilder &builder, Location loc, ValueRange args) {
565 assert(args.size() == 4);
566 auto input = args[0];
567 auto scale = args[1];
568 auto zeroPoint = args[2];
571 convertRanked(builder, loc, op, input, {}, scale,
572 zeroPoint, quantizedType);
574 linalg::YieldOp::create(builder, loc, result);
// Lowers a per-channel quant op.
// - The channel axis is the type's quantized dimension; its size is the
//   number of scales.
// - Unranked inputs are flattened to rank 3 around that axis, converted with
//   convertPerChannelRanked, and reshaped back to their original shape
//   afterwards.
// NOTE(review): the following are not visible in this listing -- confirm:
// - the `input` parameter line and the `if (isUnranked)` guards;
// - the adjustment of `channelAxis` after flattening (presumably to 1, the
//   middle dimension).
592 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
594 UniformQuantizedPerAxisType quantizedType) {
596 bool isUnranked = isa<UnrankedTensorType>(input.getType());
597 int64_t channelAxis = quantizedType.getQuantizedDimension();
598 int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
601 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
602 builder, loc, input, channelAxis, channelAxisSize);
607 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
612 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
// Lowers a sub-channel (block-wise) quant op on ranked tensor `input` into a
// linalg::GenericOp with all-parallel iterators.
// Indexing maps:
// - input and output use identity maps;
// - scales and zero points share one affine map: result i is
//   `d_i floordiv blockSize` for each quantized dimension i listed in the
//   block-size info, and the constant 0 for every other dimension.
// Output element type: as in the per-channel case, the storage type for
// float inputs and the expressed type otherwise.
// NOTE(review): `input` is cast to RankedTensorType (the cast asserts on
// other types). Unlike the per-layer and per-channel paths, no
// unranked-tensor handling is visible here -- confirm that unranked inputs
// are rejected before this point.
628 Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
630 UniformQuantizedSubChannelType quantizedType) {
631 auto *context = builder.getContext();
633 auto inputType = cast<RankedTensorType>(input.getType());
634 auto inputRank = inputType.getRank();
636 auto scales = materializeSubChannelScales(builder, loc, quantizedType);
638 materializeSubChannelZeroPoints(builder, loc, quantizedType);
640 auto elementType = isa<FloatType>(inputType.getElementType())
641 ? quantizedType.getStorageType()
642 : quantizedType.getExpressedType();
644 Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
646 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
647 utils::IteratorType::parallel);
648 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
649 quantizedType.getBlockSizeInfo();
650 SmallVector<AffineExpr> affineExprs(inputRank,
651 builder.getAffineConstantExpr(0));
652 for (
auto [quantizedDimension, blockSize] : blockSizeInfo) {
653 affineExprs[quantizedDimension] =
654 builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
656 auto affineMap =
AffineMap::get(inputRank, 0, affineExprs, context);
657 SmallVector<AffineMap> indexingMaps{
658 builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
659 builder.getMultiDimIdentityMap(inputRank)};
660 auto result = linalg::GenericOp::create(
663 ValueRange{input, scales, zeroPoints},
665 indexingMaps, iteratorTypes,
666 [&](OpBuilder &builder, Location loc, ValueRange args) {
667 assert(args.size() == 4);
668 auto input = args[0];
669 auto scale = args[1];
670 auto zeroPoint = args[2];
673 convertRanked(builder, loc, op, input, {}, scale,
674 zeroPoint, quantizedType);
676 linalg::YieldOp::create(builder, loc, result);
// Dispatches on the concrete quantized element type:
// - UniformQuantizedType goes to convertPerLayer;
// - UniformQuantizedPerAxisType goes to convertPerChannel;
// - UniformQuantizedSubChannelType goes to convertSubChannel.
// Any other type is a programming error (llvm_unreachable).
696 Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
697 Value input, Type quantizedType) {
698 if (
auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
699 return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
701 if (
auto uniformQuantizedPerAxisType =
702 dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
703 return convertPerChannel(builder, loc, op, input,
704 uniformQuantizedPerAxisType);
706 if (
auto uniformQuantizedSubChannelType =
707 dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
708 return convertSubChannel(builder, loc, op, input,
709 uniformQuantizedSubChannelType);
711 llvm_unreachable(
"unexpected quantized type");
// Lowers quant::DequantizeCastOp:
// 1. Reinterpret the quantized input as its storage integer type with
//    quant::StorageCastOp.
// 2. Apply the arithmetic lowering from convertQuantized.
// 3. Replace the op with the result.
// NOTE(review): the pattern reads the original op's input (op.getInput())
// rather than the adaptor's operand. That is only equivalent while no type
// conversion rewrites the operand -- confirm.
715 struct DequantizeCastOpConversion
716 :
public OpConversionPattern<quant::DequantizeCastOp> {
720 matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter)
const override {
722 auto loc = op.getLoc();
723 auto input = op.getInput();
725 cast<QuantizedType>(getScalarType(op.getInput().getType()));
728 auto storageScalarOrTensorType =
729 getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
730 input = quant::StorageCastOp::create(rewriter, loc,
731 storageScalarOrTensorType, input);
733 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
735 rewriter.replaceOp(op, result);
// Lowers quant::QuantizeCastOp:
// 1. Apply the arithmetic lowering from convertQuantized to the float input,
//    producing storage integers.
// 2. Replace the op with a quant::StorageCastOp from those integers back to
//    the op's quantized result type.
// NOTE(review): as in the dequantize pattern, the original op's input is
// used rather than the adaptor's operand -- confirm this is intended.
741 struct QuantizeCastOpConversion
742 :
public OpConversionPattern<quant::QuantizeCastOp> {
746 matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter)
const override {
748 auto loc = op.getLoc();
749 auto input = op.getInput();
750 auto quantizedType = getScalarType(op.getResult().getType());
753 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
756 rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
757 op, op.getResult().getType(), result);
762 struct LowerQuantOps :
public impl::LowerQuantOpsBase<LowerQuantOps> {
763 void runOnOperation()
override {
768 target.addLegalOp<quant::StorageCastOp>();
769 target.addIllegalDialect<quant::QuantDialect>();
770 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
771 shape::ShapeDialect, tensor::TensorDialect>();
782 patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
static MLIRContext * getContext(OpFoldResult val)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.