28 #define GEN_PASS_DEF_LOWERQUANTOPS
29 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
// Returns the element type when `inputType` is a tensor type. The non-tensor
// (scalar) fallthrough return is elided in this excerpt — presumably it
// returns `inputType` unchanged; confirm against the full file.
35 Type getScalarType(Type inputType) {
36 if (
auto tensorType = dyn_cast<TensorType>(inputType))
37 return tensorType.getElementType();
// Returns the shape of `input` as mixed (static/dynamic) OpFoldResults when it
// is a tensor. The tensor branch body and the scalar (empty-shape) return are
// elided in this excerpt.
44 SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45 Location loc, Value input) {
46 if (isa<TensorType>(input.getType()))
// Clones `referenceType` with `elementType` when the reference is a tensor
// (preserving its shape). The scalar fallthrough return is elided in this
// excerpt — presumably it returns `elementType` itself; confirm in full file.
54 Type getScalarOrTensorType(Type elementType, Type referenceType) {
55 if (
auto tensorType = dyn_cast<TensorType>(referenceType))
56 return tensorType.clone(elementType);
// Broadcasts `scalar` to match a reference type: if `referenceType` (parameter
// declaration elided in this excerpt) is not a tensor, the scalar case asserts
// that `referenceShape` is empty; otherwise the scalar is splatted into a
// tensor of the reference shape via tensor.splat.
63 Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
65 ArrayRef<OpFoldResult> referenceShape) {
67 auto tensorType = dyn_cast<TensorType>(referenceType);
// Scalar path: a non-tensor reference must carry no shape information.
69 assert(referenceShape.empty());
75 builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
76 return tensorConstant;
// Reshapes an unranked tensor into a rank-1 tensor of its total element count
// so it can be processed with ranked ops. Returns the flattened tensor plus
// the original extent-tensor shape (for later restoration via
// restoreUnrankedTensorShape). Several declarations (shapeType,
// flatShapeType, flatInputType, the input parameter) are elided in this
// excerpt.
92 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
95 auto *context = builder.getContext();
// Capture the dynamic shape and total element count of the input.
97 auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
98 Value inputSize = builder.create<shape::NumElementsOp>(
99 loc, builder.getIndexType(), inputShape);
// Build the 1-D target shape [inputSize] as a tensor for tensor.reshape.
103 auto flatInputShape =
104 builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
107 auto inputType = cast<UnrankedTensorType>(input.getType());
108 auto elementType = inputType.getElementType();
111 auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
// Second element preserves the original shape for the caller to restore.
113 return std::make_pair(flatInput, inputShape);
// Reshapes an unranked tensor into a rank-3 tensor of shape
// [dynamic, axisSize, dynamic], collapsing all dimensions left of `axis` and
// all dimensions right of it, while keeping the axis dimension (of known size
// `axisSize`) intact. This lets per-axis quantization logic address the
// quantized dimension at a fixed position (axis 1). Returns the flattened
// tensor and the original extent-tensor shape for later restoration.
// Some declarations (shapeType, sizeLeft/shapeLeft, sizeRight/shapeRight,
// flatShapeType, flatInputType) are elided in this excerpt.
138 std::pair<Value, Value>
139 flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
140 int64_t axis, int64_t axisSize) {
142 auto *context = builder.getContext();
143 auto indexType = builder.getIndexType();
145 auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
// Split points: [0, axis) on the left, [axis+1, rank) on the right.
148 auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
149 auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
// Number of elements collapsed to the left of the axis.
152 .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
153 inputShape, axisValue)
156 builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
// Number of elements collapsed to the right of the axis.
159 .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
160 inputShape, axisNextValue)
163 builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
// Target shape [sizeLeft, axisSize, sizeRight] as a tensor for tensor.reshape.
166 auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
168 auto flatInputShape = builder.create<tensor::FromElementsOp>(
169 loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
172 auto inputType = cast<UnrankedTensorType>(input.getType());
173 auto elementType = inputType.getElementType();
// Static result type: only the axis dimension is statically known.
175 {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
176 auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
179 return std::make_pair(flatInput, inputShape);
// Inverse of the flatten helpers: reshapes a ranked `input` back to its
// original (unranked) shape using the saved extent tensor. The `inputShape`
// parameter declaration and `unrankedType` construction are elided in this
// excerpt.
190 Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
192 auto inputType = cast<RankedTensorType>(input.getType());
193 auto elementType = inputType.getElementType();
195 return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
// Materializes the per-axis scales of `quantizedType` as a 1-D dense constant
// tensor of the expressed (float) type. The `tensorType`/`scalesAttr`
// construction between the map and the constant is elided in this excerpt.
208 Value materializePerChannelScales(OpBuilder &builder, Location loc,
209 UniformQuantizedPerAxisType quantizedType) {
210 auto scales = quantizedType.getScales();
211 auto expressedType = quantizedType.getExpressedType();
// One float attribute per channel scale.
212 auto scaleAttrs = llvm::map_to_vector(scales, [&](
double scale) -> Attribute {
213 return builder.getFloatAttr(expressedType, scale);
218 return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
// Materializes the per-axis zero points of `quantizedType` as a 1-D dense
// constant tensor of the storage (integer) type. Mirrors
// materializePerChannelScales; the `tensorType`/`zeroPointsAttr` construction
// is elided in this excerpt.
230 Value materializePerChannelZeroPoints(
231 OpBuilder &builder, Location loc,
232 UniformQuantizedPerAxisType quantizedType) {
233 auto zeroPoints = quantizedType.getZeroPoints();
234 auto storageType = quantizedType.getStorageType();
// One integer attribute per channel zero point.
235 auto zeroPointAttrs =
236 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
237 return builder.getIntegerAttr(storageType, zeroPoint);
242 return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
// Materializes the block-wise (sub-channel) scales of `quantizedType` as a
// dense constant tensor of the expressed type. Unlike the per-channel variant,
// scales come from a DenseElementsAttr, hence the APFloat iteration. The
// `tensorType`/`scalesAttr` construction is elided in this excerpt.
254 Value materializeSubChannelScales(
255 OpBuilder &builder, Location loc,
256 UniformQuantizedSubChannelType quantizedType) {
257 auto scales = quantizedType.getScales();
258 auto expressedType = quantizedType.getExpressedType();
259 auto scaleAttrs = llvm::map_to_vector(
260 scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
261 return builder.getFloatAttr(expressedType, scale);
266 return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
// Materializes the block-wise (sub-channel) zero points of `quantizedType` as
// a dense constant tensor of the storage type. Mirrors
// materializeSubChannelScales but iterates APInt values. The
// `tensorType`/`zeroPointsAttr` construction is elided in this excerpt.
278 Value materializeSubChannelZeroPoints(
279 OpBuilder &builder, Location loc,
280 UniformQuantizedSubChannelType quantizedType) {
281 auto zeroPoints = quantizedType.getZeroPoints();
282 auto storageType = quantizedType.getStorageType();
283 auto zeroPointAttrs = llvm::map_to_vector(
284 zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
285 return builder.getIntegerAttr(storageType, zeroPoint);
290 return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
// Clamps `input` (scalar or tensor of storage-type integers) to the storage
// type's min/max bounds, if the quantized type declares any. Uses signed or
// unsigned min/max ops depending on the type's signedness. The early-return
// body (when there are no bounds), the `else` keyword before the unsigned
// branch, and the final return are elided in this excerpt.
306 Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
307 ArrayRef<OpFoldResult> inputShape,
308 QuantizedType quantizedType) {
// No explicit bounds on the storage type: nothing to clamp.
311 if (!quantizedType.hasStorageTypeBounds())
315 auto inputType = input.getType();
316 auto storageType = quantizedType.getStorageType();
// Materialize the bounds as scalars, then broadcast to the input's shape.
317 auto storageMinScalar = builder.create<arith::ConstantIntOp>(
318 loc, quantizedType.getStorageTypeMin(), storageType);
319 auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
320 loc, quantizedType.getStorageTypeMax(), storageType);
321 auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
322 inputType, inputShape);
323 auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
324 inputType, inputShape);
// clamp(x) = min(max(x, lo), hi), signed or unsigned per the quantized type.
327 if (quantizedType.isSigned()) {
328 input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
329 input = builder.create<arith::MinSIOp>(loc, input, storageMax);
331 input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
332 input = builder.create<arith::MinUIOp>(loc, input, storageMax);
// Converts a float scalar/tensor to an integer one: fptosi when `isSigned`,
// fptoui otherwise. The `if (isSigned)` guard line between the two returns is
// elided in this excerpt.
338 Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
339 Type resultType,
bool isSigned) {
341 return builder.create<arith::FPToSIOp>(loc, resultType, input);
342 return builder.create<arith::FPToUIOp>(loc, resultType, input);
// Converts an integer scalar/tensor to a float one: sitofp when `isSigned`,
// uitofp otherwise. The `if (isSigned)` guard line between the two returns is
// elided in this excerpt.
346 Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
347 Type resultType,
bool isSigned) {
349 return builder.create<arith::SIToFPOp>(loc, resultType, input);
350 return builder.create<arith::UIToFPOp>(loc, resultType, input);
// Lowers quantization of a float value to arith/tensor ops:
//   stored = clamp(convert_to_int(input / scale + zeroPoint))
// where the zero-point addition happens in the expressed (float) domain and is
// presumably guarded by a non-zero-zero-point check elided in this excerpt
// (the `if` around the zeroPoint handling is not visible here — confirm).
357 Value quantizeValue(OpBuilder &builder, Location loc, Value input,
358 ArrayRef<OpFoldResult> inputShape, Value scale,
359 Value zeroPoint, QuantizedType quantizedType) {
// Broadcast scale to the input's scalar-or-tensor shape.
361 auto inputType = input.getType();
362 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Scale down the expressed value: input / scale.
365 auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
368 Value storedValueFloat = scaledValue;
// Broadcast the (integer) zero point, convert it to float, and add it.
371 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
375 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
376 quantizedType.isSigned());
380 builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
// Round-trip into the storage integer type, then clamp to its bounds.
384 auto storageScalarOrTensorType =
385 getScalarOrTensorType(quantizedType.getStorageType(), inputType);
386 auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
387 storageScalarOrTensorType,
388 quantizedType.isSigned());
391 auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
392 inputShape, quantizedType);
393 return storedValueClamped;
// Lowers dequantization of a stored integer value to arith/tensor ops:
//   result = (convert_to_float(input) - zeroPoint) * scale
// The zero-point subtraction is presumably guarded by a non-zero-zero-point
// check elided in this excerpt, and the final `return result;` is also not
// visible here.
399 Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
400 ArrayRef<OpFoldResult> inputShape, Value scale,
401 Value zeroPoint, QuantizedType quantizedType) {
// Broadcast scale to the input's scalar-or-tensor shape.
403 auto inputType = input.getType();
404 scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
// Convert the stored integer into the expressed float domain.
407 auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
408 quantizedType.isSigned());
// Broadcast, float-convert, and subtract the zero point.
413 zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
417 zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
418 quantizedType.isSigned());
421 result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
// Scale back up into the expressed value range.
425 result = builder.create<arith::MulFOp>(loc, result, scale);
// Dispatches a ranked conversion to quantizeValue or dequantizeValue based on
// whether `op` is a quant.qcast or quant.dcast. Any other op is a programming
// error. The trailing `quantizedType` argument lines of each call are elided
// in this excerpt.
449 Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
450 Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
451 Value zeroPoint, QuantizedType quantizedType) {
452 if (isa<QuantizeCastOp>(op))
453 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
455 if (isa<DequantizeCastOp>(op))
456 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
458 llvm_unreachable(
"unexpected quant op");
// Converts a ranked (or scalar) value with a per-layer (uniform) quantized
// type: materializes the single scale and zero point as scalar constants and
// delegates to convertRanked. The `scaleAttr`/`zeroPointAttr` declarations
// and the trailing `quantizedType` argument of convertRanked are elided in
// this excerpt.
473 Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
474 Value input, UniformQuantizedType quantizedType) {
476 auto expressedType = quantizedType.getExpressedType();
477 auto storageType = quantizedType.getStorageType();
// Scalar scale constant in the expressed (float) type.
479 builder.getFloatAttr(expressedType, quantizedType.getScale());
480 auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
// Scalar zero-point constant in the storage (integer) type.
482 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
484 builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
486 auto inputShape = getScalarOrTensorShape(builder, loc, input);
487 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
// Per-layer conversion entry point that also handles unranked tensors:
// flatten to rank 1 first, convert, then restore the original shape. The
// `inputShape` declaration, the `if (isUnranked)` guards, and the final
// return are elided in this excerpt.
502 Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
503 Value input, UniformQuantizedType quantizedType) {
505 bool isUnranked = isa<UnrankedTensorType>(input.getType());
508 std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
511 auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
515 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
// Converts a ranked tensor with a per-channel (per-axis) quantized type by
// emitting a linalg.generic that, for each element, applies the scalar
// quantize/dequantize with the scale/zero-point of its channel. The channel
// constants are indexed via an affine map projecting onto `channelAxis`.
// Several lines (the `input` parameter, `zeroPoints` declaration, `initShape`,
// `channelAxisAffineMap` declaration, the generic op's result types/outputs,
// and the final result extraction/return) are elided in this excerpt.
532 Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
534 UniformQuantizedPerAxisType quantizedType,
535 int64_t channelAxis) {
536 auto *context = builder.getContext();
538 auto inputType = cast<RankedTensorType>(input.getType());
539 auto inputRank = inputType.getRank();
// 1-D constants holding one scale / zero point per channel.
541 auto scales = materializePerChannelScales(builder, loc, quantizedType);
543 materializePerChannelZeroPoints(builder, loc, quantizedType);
// Output element type flips domains: float input -> integer storage
// (quantize), integer input -> float expressed (dequantize).
545 auto elementType = isa<FloatType>(inputType.getElementType())
546 ? quantizedType.getStorageType()
547 : quantizedType.getExpressedType();
549 Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
// Elementwise: all dimensions are parallel.
551 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
552 utils::IteratorType::parallel);
// Map (d0..dN) -> (d_channelAxis): scales/zeroPoints are indexed by channel.
554 inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
555 SmallVector<AffineMap> indexingMaps{
556 builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
557 channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
558 auto result = builder
559 .create<linalg::GenericOp>(
562 ValueRange{input, scales, zeroPoints},
564 indexingMaps, iteratorTypes,
565 [&](OpBuilder &builder, Location loc, ValueRange args) {
// args = {input element, scale, zero point, init element}.
566 assert(args.size() == 4);
567 auto input = args[0];
568 auto scale = args[1];
569 auto zeroPoint = args[2];
// Scalar conversion inside the generic body; empty shape = scalar.
572 convertRanked(builder, loc, op, input, {}, scale,
573 zeroPoint, quantizedType);
575 builder.create<linalg::YieldOp>(loc, result);
// Per-channel conversion entry point that also handles unranked tensors:
// flatten around the quantized axis (so it sits at a fixed position with
// static size), convert, then restore the original shape. The `input`
// parameter, `inputShape` declaration, the `if (isUnranked)` guards, the
// adjusted channel axis passed to convertPerChannelRanked, and the final
// return are elided in this excerpt.
593 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
595 UniformQuantizedPerAxisType quantizedType) {
597 bool isUnranked = isa<UnrankedTensorType>(input.getType());
598 int64_t channelAxis = quantizedType.getQuantizedDimension();
// Axis size is statically known: one scale per channel.
599 int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
602 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
603 builder, loc, input, channelAxis, channelAxisSize);
608 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
613 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
// Converts a ranked tensor with a sub-channel (block-wise) quantized type via
// a linalg.generic. Scales/zero points are indexed by dividing each quantized
// dimension's index by its block size (floorDiv); non-quantized dimensions map
// to constant 0. Structure mirrors convertPerChannelRanked. Several lines
// (the `input` parameter, `zeroPoints`/`initShape` declarations, the generic
// op's result types/outputs, and the final result extraction/return) are
// elided in this excerpt.
629 Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
631 UniformQuantizedSubChannelType quantizedType) {
632 auto *context = builder.getContext();
634 auto inputType = cast<RankedTensorType>(input.getType());
635 auto inputRank = inputType.getRank();
// Dense constants holding the block-wise scales / zero points.
637 auto scales = materializeSubChannelScales(builder, loc, quantizedType);
639 materializeSubChannelZeroPoints(builder, loc, quantizedType);
// Output element type flips domains: float input -> integer storage
// (quantize), integer input -> float expressed (dequantize).
641 auto elementType = isa<FloatType>(inputType.getElementType())
642 ? quantizedType.getStorageType()
643 : quantizedType.getExpressedType();
645 Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
// Elementwise: all dimensions are parallel.
647 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
648 utils::IteratorType::parallel);
649 const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
650 quantizedType.getBlockSizeInfo();
// Default every dimension to index 0; quantized dimensions are overridden
// below with d_i floordiv blockSize so each block shares one scale/zp.
651 SmallVector<AffineExpr> affineExprs(inputRank,
652 builder.getAffineConstantExpr(0));
653 for (
auto [quantizedDimension, blockSize] : blockSizeInfo) {
654 affineExprs[quantizedDimension] =
655 builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
657 auto affineMap =
AffineMap::get(inputRank, 0, affineExprs, context);
658 SmallVector<AffineMap> indexingMaps{
659 builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
660 builder.getMultiDimIdentityMap(inputRank)};
661 auto result = builder
662 .create<linalg::GenericOp>(
665 ValueRange{input, scales, zeroPoints},
667 indexingMaps, iteratorTypes,
668 [&](OpBuilder &builder, Location loc, ValueRange args) {
// args = {input element, scale, zero point, init element}.
669 assert(args.size() == 4);
670 auto input = args[0];
671 auto scale = args[1];
672 auto zeroPoint = args[2];
// Scalar conversion inside the generic body; empty shape = scalar.
675 convertRanked(builder, loc, op, input, {}, scale,
676 zeroPoint, quantizedType);
678 builder.create<linalg::YieldOp>(loc, result);
// Top-level dispatch over the three supported quantized type kinds:
// per-layer (uniform), per-channel (per-axis), and sub-channel (block-wise).
// Any other quantized type is a programming error.
698 Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
699 Value input, Type quantizedType) {
700 if (
auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
701 return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
703 if (
auto uniformQuantizedPerAxisType =
704 dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
705 return convertPerChannel(builder, loc, op, input,
706 uniformQuantizedPerAxisType);
708 if (
auto uniformQuantizedSubChannelType =
709 dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
710 return convertSubChannel(builder, loc, op, input,
711 uniformQuantizedSubChannelType);
713 llvm_unreachable(
"unexpected quantized type");
// Conversion pattern lowering quant.dcast: first a quant.scast strips the
// quantized element type down to its storage integer type, then
// convertQuantized emits the arith/linalg dequantization math. The
// `quantizedType` declaration and the `return success();` are elided in this
// excerpt.
717 struct DequantizeCastOpConversion
718 :
public OpConversionPattern<quant::DequantizeCastOp> {
722 matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
723 ConversionPatternRewriter &rewriter)
const override {
724 auto loc = op.getLoc();
725 auto input = op.getInput();
// The input's scalar type is the quantized type being lowered.
727 cast<QuantizedType>(getScalarType(op.getInput().getType()));
// Reinterpret the quantized value as its raw storage integers.
730 auto storageScalarOrTensorType =
731 getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
732 input = rewriter.create<quant::StorageCastOp>(
733 loc, storageScalarOrTensorType, input);
735 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
737 rewriter.replaceOp(op, result);
// Conversion pattern lowering quant.qcast: convertQuantized emits the
// arith/linalg quantization math on the float input, then a quant.scast
// reinterprets the resulting storage integers as the quantized result type.
// The `return success();` is elided in this excerpt.
743 struct QuantizeCastOpConversion
744 :
public OpConversionPattern<quant::QuantizeCastOp> {
748 matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
749 ConversionPatternRewriter &rewriter)
const override {
750 auto loc = op.getLoc();
751 auto input = op.getInput();
// The result's scalar type is the quantized type being lowered to.
752 auto quantizedType = getScalarType(op.getResult().getType());
755 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
// Wrap the raw storage integers back into the quantized result type.
758 rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
759 op, op.getResult().getType(), result);
// Pass driver: marks the quant dialect illegal (except quant.scast, which the
// patterns deliberately emit) and arith/linalg/shape/tensor legal, then
// applies the two conversion patterns. The ConversionTarget/RewritePatternSet
// declarations, the applyPartialConversion call, and failure handling are
// elided in this excerpt.
764 struct LowerQuantOps :
public impl::LowerQuantOpsBase<LowerQuantOps> {
765 void runOnOperation()
override {
// quant.scast stays legal: it is the boundary op the lowering produces.
770 target.addLegalOp<quant::StorageCastOp>();
771 target.addIllegalDialect<quant::QuantDialect>();
772 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
773 shape::ShapeDialect, tensor::TensorDialect>();
784 patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
static MLIRContext * getContext(OpFoldResult val)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.