MLIR 22.0.0git
LowerQuantOps.cpp
Go to the documentation of this file.
1//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
10//
11//===----------------------------------------------------------------------===//
12
24
25namespace mlir {
26namespace quant {
28#define GEN_PASS_DEF_LOWERQUANTOPS
29#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
30
31namespace {
33// If 'inputType' is a tensor, return its element type. If it is a scalar,
34// return it as is.
35Type getScalarType(Type inputType) {
36 if (auto tensorType = dyn_cast<TensorType>(inputType))
37 return tensorType.getElementType();
38 return inputType;
40
41// Return the shape of an input value as a list of attributes (static
42// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43// list is returned. If 'input' is a tensor, its shape is returned.
44SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45 Location loc, Value input) {
46 if (isa<TensorType>(input.getType()))
47 return tensor::getMixedSizes(builder, loc, input);
48 return {};
49}
51// If 'referenceType' is a scalar, return 'elementType' as is. If
52// 'referenceType' is a tensor, return another tensor with the same shape and
53// elements of type 'elementType'.
54Type getScalarOrTensorType(Type elementType, Type referenceType) {
55 if (auto tensorType = dyn_cast<TensorType>(referenceType))
56 return tensorType.clone(elementType);
57 return elementType;
59
60// Return a constant with the given value. If 'referenceType' is a tensor, a
61// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
62// scalar, 'referenceShape' is ignored and a scalar constant is returned.
63Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
64 Type referenceType,
65 ArrayRef<OpFoldResult> referenceShape) {
66 // If the result type is a scalar, return the unmodified scalar constant.
67 auto tensorType = dyn_cast<TensorType>(referenceType);
68 if (!tensorType) {
69 assert(referenceShape.empty());
70 return scalar;
71 }
72
73 // Create tensor splat
74 auto tensorConstant =
75 tensor::SplatOp::create(builder, loc, scalar, referenceShape);
76 return tensorConstant;
77}
78
79// Reshape an unranked tensor into a 1D ranked tensor.
80//
81// - input
82// Unranked tensor.
83//
84// Return values:
85//
86// - flatInput
87// 1D ranked, dynamically shaped tensor.
88//
89// - inputShape
90// 1D extent tensor containing the shape of the original unranked input.
91//
92std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
93 Value input) {
94 // Get unranked input shape and total size
95 auto *context = builder.getContext();
96 auto shapeType = shape::getExtentTensorType(context);
97 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
98 Value inputSize = shape::NumElementsOp::create(
99 builder, loc, builder.getIndexType(), inputShape);
100
101 // Turn input size into 1D tensor
102 auto flatShapeType = shape::getExtentTensorType(context, 1);
103 auto flatInputShape =
104 tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
105
106 // Reshape input tensor into 1D
107 auto inputType = cast<UnrankedTensorType>(input.getType());
108 auto elementType = inputType.getElementType();
109 auto flatInputType =
110 RankedTensorType::get({ShapedType::kDynamic}, elementType);
111 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
112 flatInputShape);
113 return std::make_pair(flatInput, inputShape);
114}
115
// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input,
//   needed later to restore the original shape.
//
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);

  // Get shape and sizes on left and right of axis. 'shape.split_at' divides
  // the shape at the given index; the element count of each half is the
  // collapsed size of all dimensions before (resp. after) 'axis'.
  auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
  auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
  auto shapeLeft =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
  auto shapeRight =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      shape::NumElementsOp::create(builder, loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor:
  // [sizeLeft, axisSize, sizeRight].
  auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = tensor::FromElementsOp::create(
      builder, loc, flatShapeType,
      ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to a [?, axisSize, ?] 3D tensor.
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);

  return std::make_pair(flatInput, inputShape);
}
180
181// Reshape an input tensor into its original unranked shape.
182//
183// - input
184// Ranked tensor.
185//
186// - inputShape
187// 1D extent tensor.
188//
189Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
190 Value inputShape) {
191 auto inputType = cast<RankedTensorType>(input.getType());
192 auto elementType = inputType.getElementType();
193 auto unrankedType = UnrankedTensorType::get(elementType);
194 return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
195 inputShape);
196}
197
198// Create a tensor constant containing all scales in a per-channel quantized
199// type. Example:
200//
201// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
202//
203// produces
204//
205// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
206//
207Value materializePerChannelScales(OpBuilder &builder, Location loc,
208 UniformQuantizedPerAxisType quantizedType) {
209 auto scales = quantizedType.getScales();
210 auto expressedType = quantizedType.getExpressedType();
211 auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
212 return builder.getFloatAttr(expressedType, scale);
213 });
214 auto tensorType =
215 RankedTensorType::get({(int64_t)scales.size()}, expressedType);
216 auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
217 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
218}
219
220// Create a tensor constant containing all zero points in a per-channel
221// quantized type. Example:
222//
223// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
224//
225// produces
226//
227// %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
228//
229Value materializePerChannelZeroPoints(
230 OpBuilder &builder, Location loc,
231 UniformQuantizedPerAxisType quantizedType) {
232 auto zeroPoints = quantizedType.getZeroPoints();
233 auto storageType = quantizedType.getStorageType();
234 auto zeroPointAttrs =
235 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
236 return builder.getIntegerAttr(storageType, zeroPoint);
237 });
238 auto tensorType =
239 RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
240 auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
241 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
242}
243
244// Create a tensor constant containing all scales in a sub-channel quantized
245// type. Example:
246//
247// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
248//
249// produces
250//
251// %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
252//
253Value materializeSubChannelScales(
254 OpBuilder &builder, Location loc,
255 UniformQuantizedSubChannelType quantizedType) {
256 auto scales = quantizedType.getScales();
257 auto expressedType = quantizedType.getExpressedType();
258 auto scaleAttrs = llvm::map_to_vector(
259 scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
260 return builder.getFloatAttr(expressedType, scale);
261 });
262 auto tensorType =
263 RankedTensorType::get(scales.getType().getShape(), expressedType);
264 auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
265 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
266}
267
268// Create a tensor constant containing all zero points in a sub-channel
269// quantized type. Example:
270//
271// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
272//
273// produces
274//
275// %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
276//
277Value materializeSubChannelZeroPoints(
278 OpBuilder &builder, Location loc,
279 UniformQuantizedSubChannelType quantizedType) {
280 auto zeroPoints = quantizedType.getZeroPoints();
281 auto storageType = quantizedType.getStorageType();
282 auto zeroPointAttrs = llvm::map_to_vector(
283 zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
284 return builder.getIntegerAttr(storageType, zeroPoint);
285 });
286 auto tensorType =
287 RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
288 auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
289 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
290}
291
// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-axis or per-channel quantized type.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If quantized type does not narrow down the storage type range, there is
  // nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize the bounds as scalar constants of the storage type, then
  // broadcast them to the input's shape if the input is a tensor.
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMin());
  auto storageMaxScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMax());
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp with signed or unsigned min/max depending on the storage type's
  // signedness.
  if (quantizedType.isSigned()) {
    input = arith::MaxSIOp::create(builder, loc, input, storageMin);
    input = arith::MinSIOp::create(builder, loc, input, storageMax);
  } else {
    input = arith::MaxUIOp::create(builder, loc, input, storageMin);
    input = arith::MinUIOp::create(builder, loc, input, storageMax);
  }
  return input;
}
335
336// Emit op 'arith.fptosi' or 'arith.fptoui'.
337Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
338 Type resultType, bool isSigned) {
339 if (isSigned)
340 return arith::FPToSIOp::create(builder, loc, resultType, input);
341 return arith::FPToUIOp::create(builder, loc, resultType, input);
342}
343
344// Emit op 'arith.sitofp' or 'arith.uitofp'.
345Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
346 Type resultType, bool isSigned) {
347 if (isSigned)
348 return arith::SIToFPOp::create(builder, loc, resultType, input);
349 return arith::UIToFPOp::create(builder, loc, resultType, input);
350}
351
// Quantize a scalar or ranked tensor value: divide by the scale, add the zero
// point, convert to the storage type, and clamp to the storage bounds encoded
// in the given quantized type, if any.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);

  // Skip unnecessary computations if no zero point is given (i.e. the zero
  // point constant-folds to 0).
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp stored value if the quantized type narrows the storage type range
  // (no-op otherwise, see 'clampScalarOrTensor').
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}
394
// Dequantize a scalar or ranked tensor input: convert to the expressed float
// type, subtract the zero point, and multiply by the scale.
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given (i.e. the zero
  // point constant-folds to 0).
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = arith::SubFOp::create(builder, loc, result, zeroPoint);
  }

  // Multiply by scale
  result = arith::MulFOp::create(builder, loc, result, scale);
  return result;
}
427
428// Convert a scalar or ranked tensor input with the given scale and zero point
429// values.
430//
431// - input
432// Scalar or ranked tensor value.
433//
434// - inputShape
435// If 'input' is a tensor, combination or attributes/values representing its
436// static/dynamic dimensions. If 'input' is a scalar, empty list.
437//
438// - scale
439// Scale as a floating-point scalar value.
440//
441// - zeroPoint
442// Zero point as an integer scalar value.
443//
444// - quantizedType
445// Scalar quantized type of the result ('quant.qcast') or of the input
446// ('quant.dcast').
447//
448Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
449 Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
450 Value zeroPoint, QuantizedType quantizedType) {
451 if (isa<QuantizeCastOp>(op))
452 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
453 quantizedType);
454 if (isa<DequantizeCastOp>(op))
455 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
456 quantizedType);
457 llvm_unreachable("unexpected quant op");
458}
459
460// Convert an operation using per-layer quantization with a scalar or ranked
461// tensor input.
462//
463// - op
464// 'quant.dcast' or 'quant.qcast' op.
465//
466// - input
467// Scalar or ranked tensor.
468//
469// - quantizedType
470// Per-layer quantized type.
471//
472Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
473 Value input, UniformQuantizedType quantizedType) {
474 // Create scale and zero point constants
475 auto expressedType = quantizedType.getExpressedType();
476 auto storageType = quantizedType.getStorageType();
477 auto scaleAttr =
478 builder.getFloatAttr(expressedType, quantizedType.getScale());
479 auto scale =
480 arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
481 auto zeroPointAttr =
482 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
483 auto zeroPoint =
484 arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
485
486 auto inputShape = getScalarOrTensorShape(builder, loc, input);
487 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
488 quantizedType);
489}
490
491// Convert an operation using per-layer quantization.
492//
493// - op
494// 'quant.dcast' or 'quant.qcast' op.
495//
496// - input
497// Scalar, ranked tensor, or unranked tensor.
498//
499// - quantizedType
500// Per-layer quantized type.
501//
502Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
503 Value input, UniformQuantizedType quantizedType) {
504 // Flatten input if unranked
505 bool isUnranked = isa<UnrankedTensorType>(input.getType());
506 Value inputShape;
507 if (isUnranked)
508 std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
509
510 // Process ranked tensor
511 auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
512
513 // Restore original shape if unranked
514 if (isUnranked)
515 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
516
517 return result;
518}
519
// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
// - channelAxis
//   Input dimension that carries the quantization parameters.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  // Materialize per-channel scales and zero points as 1D tensor constants.
  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // The result element type is the opposite side of the cast: a float input
  // is being quantized to the storage type, an integer input is being
  // dequantized to the expressed type.
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  // Elementwise 'linalg.generic': the identity maps walk the input/output,
  // while 'channelAxisAffineMap' broadcasts the 1D scale and zero-point
  // tensors along dimension 'channelAxis'.
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      // args = (input element, scale, zero point, init).
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      // Convert each scalar element with its channel's scale
                      // and zero point.
                      auto result =
                          convertRanked(builder, loc, op, input, {}, scale,
                                        zeroPoint, quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}
580
581// Convert an operation using per-channel quantization.
582//
583// - op
584// 'quant.dcast' or 'quant.qcast' op.
585//
586// - input
587// Scalar, ranked tensor, or unranked tensor.
588//
589// - quantizedType
590// Per-channel quantized type.
591//
592Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
593 Value input,
594 UniformQuantizedPerAxisType quantizedType) {
595 // Flatten unranked tensor into a 3D ranked tensor if necessary
596 bool isUnranked = isa<UnrankedTensorType>(input.getType());
597 int64_t channelAxis = quantizedType.getQuantizedDimension();
598 int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
599 Value inputShape;
600 if (isUnranked) {
601 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
602 builder, loc, input, channelAxis, channelAxisSize);
603 channelAxis = 1;
604 }
605
606 // Work on a ranked tensor
607 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
608 channelAxis);
609
610 // Restore original tensor shape if unranked
611 if (isUnranked)
612 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
613
614 return result;
615}
616
// Convert an operation using sub-channel (blockwise) quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Sub-channel quantized type.
//
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  // Materialize scales and zero points as multi-dimensional tensor constants
  // (one entry per block).
  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  // The result element type is the opposite side of the cast: a float input
  // is being quantized to the storage type, an integer input is being
  // dequantized to the expressed type.
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  // Build the indexing map for the scale/zero-point tensors: quantized
  // dimension d with block size b maps to d floordiv b, so all elements of a
  // block read the same parameter entry; non-quantized dimensions map to 0.
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      quantizedType.getBlockSizeInfo();
  SmallVector<AffineExpr> affineExprs(inputRank,
                                      builder.getAffineConstantExpr(0));
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    affineExprs[quantizedDimension] =
        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
  }
  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
      builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      // args = (input element, scale, zero point, init).
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      // Convert each scalar element with its block's scale
                      // and zero point.
                      auto result =
                          convertRanked(builder, loc, op, input, {}, scale,
                                        zeroPoint, quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}
682
683// Convert a quantization operation.
684//
685// - op
686// 'quant.dcast' or 'quant.qcast' op.
687//
688// - input
689// Scalar, ranked tensor, or unranked tensor. The element type matches
690// the storage type (quant.dcast) or expressed type (quant.qcast) of
691// 'quantizedType'.
692//
693// - quantizedType
694// Per-layer or per-channel quantized type.
695//
696Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
697 Value input, Type quantizedType) {
698 if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
699 return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
700
701 if (auto uniformQuantizedPerAxisType =
702 dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
703 return convertPerChannel(builder, loc, op, input,
704 uniformQuantizedPerAxisType);
705
706 if (auto uniformQuantizedSubChannelType =
707 dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
708 return convertSubChannel(builder, loc, op, input,
709 uniformQuantizedSubChannelType);
710
711 llvm_unreachable("unexpected quantized type");
712}
713
// Lowering pattern for 'quant.dcast': strip the quantized type with a
// 'quant.scast' to the storage type, then emit the arith/linalg ops that
// compute the expressed values.
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    // The input's (element) type carries the quantization parameters.
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = quant::StorageCastOp::create(rewriter, loc,
                                         storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};
739
// Lowering pattern for 'quant.qcast': emit the arith/linalg ops that compute
// the stored values, then 'quant.scast' the result to the quantized type.
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    // The result's (element) type carries the quantization parameters.
    auto quantizedType = getScalarType(op.getResult().getType());

    // Lower the cast to arith/linalg ops producing storage-typed values.
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};
761
762struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
763 void runOnOperation() override {
764 RewritePatternSet patterns(&getContext());
766
767 ConversionTarget target(getContext());
768 target.addLegalOp<quant::StorageCastOp>();
769 target.addIllegalDialect<quant::QuantDialect>();
770 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
771 shape::ShapeDialect, tensor::TensorDialect>();
772
773 if (failed(applyPartialConversion(getOperation(), target,
774 std::move(patterns))))
775 signalPassFailure();
776 }
777};
778
779} // namespace
780
782 patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
783 patterns.getContext());
784}
785
786} // namespace quant
787} // namespace mlir
return success()
b getContext())
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:258
Base class for all quantized types known to this dialect.
Definition QuantTypes.h:50
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
Represents a family of uniform, quantized types.
Definition QuantTypes.h:264
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns