MLIR 23.0.0git
LowerQuantOps.cpp
Go to the documentation of this file.
1//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
10//
11//===----------------------------------------------------------------------===//
12
21#include "mlir/IR/Matchers.h"
24
25namespace mlir {
26namespace quant {
27
28#define GEN_PASS_DEF_LOWERQUANTOPS
29#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
30
31namespace {
32
33// If 'inputType' is a tensor, return its element type. If it is a scalar,
34// return it as is.
35Type getScalarType(Type inputType) {
36 if (auto tensorType = dyn_cast<TensorType>(inputType))
37 return tensorType.getElementType();
38 return inputType;
39}
40
41// Return the shape of an input value as a list of attributes (static
42// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43// list is returned. If 'input' is a tensor, its shape is returned.
44SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45 Location loc, Value input) {
46 if (isa<TensorType>(input.getType()))
47 return tensor::getMixedSizes(builder, loc, input);
48 return {};
49}
50
51// If 'referenceType' is a scalar, return 'elementType' as is. If
52// 'referenceType' is a tensor, return another tensor with the same shape and
53// elements of type 'elementType'.
54Type getScalarOrTensorType(Type elementType, Type referenceType) {
55 if (auto tensorType = dyn_cast<TensorType>(referenceType))
56 return tensorType.clone(elementType);
57 return elementType;
58}
59
60// Return a constant with the given value. If 'referenceType' is a tensor, a
61// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
62// scalar, 'referenceShape' is ignored and a scalar constant is returned.
63Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
64 Type referenceType,
65 ArrayRef<OpFoldResult> referenceShape) {
66 // If the result type is a scalar, return the unmodified scalar constant.
67 auto tensorType = dyn_cast<TensorType>(referenceType);
68 if (!tensorType) {
69 assert(referenceShape.empty());
70 return scalar;
71 }
72
73 // Create tensor splat
74 auto tensorConstant =
75 tensor::SplatOp::create(builder, loc, scalar, referenceShape);
76 return tensorConstant;
77}
78
79// Reshape an unranked tensor into a 1D ranked tensor.
80//
81// - input
82// Unranked tensor.
83//
84// Return values:
85//
86// - flatInput
87// 1D ranked, dynamically shaped tensor.
88//
89// - inputShape
90// 1D extent tensor containing the shape of the original unranked input.
91//
92std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
93 Value input) {
94 // Get unranked input shape and total size
95 auto *context = builder.getContext();
96 auto shapeType = shape::getExtentTensorType(context);
97 auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
98 Value inputSize = shape::NumElementsOp::create(
99 builder, loc, builder.getIndexType(), inputShape);
100
101 // Turn input size into 1D tensor
102 auto flatShapeType = shape::getExtentTensorType(context, 1);
103 auto flatInputShape =
104 tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);
105
106 // Reshape input tensor into 1D
107 auto inputType = cast<UnrankedTensorType>(input.getType());
108 auto elementType = inputType.getElementType();
109 auto flatInputType =
110 RankedTensorType::get({ShapedType::kDynamic}, elementType);
111 auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
112 flatInputShape);
113 return std::make_pair(flatInput, inputShape);
114}
115
// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);

  // Get shape and sizes on left and right of axis. Dimensions [0, axis)
  // collapse into 'sizeLeft'; dimensions (axis, rank) collapse into
  // 'sizeRight'.
  auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
  auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
  auto shapeLeft =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
  auto shapeRight =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      shape::NumElementsOp::create(builder, loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor:
  // [sizeLeft, axisSize, sizeRight].
  auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = tensor::FromElementsOp::create(
      builder, loc, flatShapeType,
      ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to 3D tensor of type tensor<?x axisSize x?>.
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);

  return std::make_pair(flatInput, inputShape);
}
180
181// Reshape an input tensor into its original unranked shape.
182//
183// - input
184// Ranked tensor.
185//
186// - inputShape
187// 1D extent tensor.
188//
189Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
190 Value inputShape) {
191 auto inputType = cast<RankedTensorType>(input.getType());
192 auto elementType = inputType.getElementType();
193 auto unrankedType = UnrankedTensorType::get(elementType);
194 return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
195 inputShape);
196}
197
198// Create a tensor constant containing all scales in a per-channel quantized
199// type. Example:
200//
201// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
202//
203// produces
204//
205// %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
206//
207Value materializePerChannelScales(OpBuilder &builder, Location loc,
208 UniformQuantizedPerAxisType quantizedType) {
209 auto scales = quantizedType.getScales();
210 auto expressedType = quantizedType.getExpressedType();
211 auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
212 return builder.getFloatAttr(expressedType, scale);
213 });
214 auto tensorType =
215 RankedTensorType::get({(int64_t)scales.size()}, expressedType);
216 auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
217 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
218}
219
220// Create a tensor constant containing all zero points in a per-channel
221// quantized type. Example:
222//
223// !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
224//
225// produces
226//
227// %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
228//
229Value materializePerChannelZeroPoints(
230 OpBuilder &builder, Location loc,
231 UniformQuantizedPerAxisType quantizedType) {
232 auto zeroPoints = quantizedType.getZeroPoints();
233 auto storageType = quantizedType.getStorageType();
234 auto zeroPointAttrs =
235 llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
236 return builder.getIntegerAttr(storageType, zeroPoint);
237 });
238 auto tensorType =
239 RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
240 auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
241 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
242}
243
244// Create a tensor constant containing all scales in a sub-channel quantized
245// type. Example:
246//
247// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
248//
249// produces
250//
251// %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
252//
253Value materializeSubChannelScales(
254 OpBuilder &builder, Location loc,
255 UniformQuantizedSubChannelType quantizedType) {
256 auto scales = quantizedType.getScales();
257 auto expressedType = quantizedType.getExpressedType();
258 auto scaleAttrs = llvm::map_to_vector(
259 scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
260 return builder.getFloatAttr(expressedType, scale);
261 });
262 auto tensorType =
263 RankedTensorType::get(scales.getType().getShape(), expressedType);
264 auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
265 return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
266}
267
268// Create a tensor constant containing all zero points in a sub-channel
269// quantized type. Example:
270//
271// !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
272//
273// produces
274//
275// %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
276//
277Value materializeSubChannelZeroPoints(
278 OpBuilder &builder, Location loc,
279 UniformQuantizedSubChannelType quantizedType) {
280 auto zeroPoints = quantizedType.getZeroPoints();
281 auto storageType = quantizedType.getStorageType();
282 auto zeroPointAttrs = llvm::map_to_vector(
283 zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
284 return builder.getIntegerAttr(storageType, zeroPoint);
285 });
286 auto tensorType =
287 RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
288 auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
289 return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
290}
291
292// Clamp the given scalar or tensor input using the storage bounds encoded in
293// the given quantized type, if present.
294//
295// - input
296// Scalar or ranked tensor input. The element type must match the storage type
297// of 'quantizedType'.
298//
299// - inputShape
300// If 'input' is a tensor, combination of attributes/values representing its
301// static/dynamic dimensions. If 'input' is a scalar, empty list.
302//
303// - quantizedType
304// Per-axis or per-channel quantized type.
305Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
306 ArrayRef<OpFoldResult> inputShape,
307 QuantizedType quantizedType) {
308 // If quantized type does not narrow down the storage type range, there is
309 // nothing to do.
310 if (!quantizedType.hasStorageTypeBounds())
311 return input;
312
313 // Materialize bounds
314 auto inputType = input.getType();
315 auto storageType = quantizedType.getStorageType();
316 auto storageMinScalar = arith::ConstantIntOp::create(
317 builder, loc, storageType, quantizedType.getStorageTypeMin());
318 auto storageMaxScalar = arith::ConstantIntOp::create(
319 builder, loc, storageType, quantizedType.getStorageTypeMax());
320 auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
321 inputType, inputShape);
322 auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
323 inputType, inputShape);
324
325 // Clamp
326 if (quantizedType.isSigned()) {
327 input = arith::MaxSIOp::create(builder, loc, input, storageMin);
328 input = arith::MinSIOp::create(builder, loc, input, storageMax);
329 } else {
330 input = arith::MaxUIOp::create(builder, loc, input, storageMin);
331 input = arith::MinUIOp::create(builder, loc, input, storageMax);
332 }
333 return input;
334}
335
336// Emit op 'arith.fptosi' or 'arith.fptoui'.
337Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
338 Type resultType, bool isSigned) {
339 if (isSigned)
340 return arith::FPToSIOp::create(builder, loc, resultType, input);
341 return arith::FPToUIOp::create(builder, loc, resultType, input);
342}
343
344// Emit op 'arith.sitofp' or 'arith.uitofp'.
345Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
346 Type resultType, bool isSigned) {
347 if (isSigned)
348 return arith::SIToFPOp::create(builder, loc, resultType, input);
349 return arith::UIToFPOp::create(builder, loc, resultType, input);
350}
351
// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// In the expressed (float) type this computes:
//
//   stored = convert(input / scale + zeroPoint)
//
// where the float-to-int conversion is 'arith.fptosi'/'arith.fptoui' (see
// convertFloatToInteger), followed by a clamp when the quantized type narrows
// the storage range.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);

  // Skip unnecessary computations if no zero point is given. m_Zero()
  // matches a scalar or splat integer zero constant.
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp stored value if the storage type is bound
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}
394
// Dequantize a scalar or ranked tensor input.
//
// In the expressed (float) type this computes:
//
//   result = (convert(input) - zeroPoint) * scale
//
// where the int-to-float conversion is 'arith.sitofp'/'arith.uitofp' (see
// convertIntegerToFloat).
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given. m_Zero()
  // matches a scalar or splat integer zero constant.
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = arith::SubFOp::create(builder, loc, result, zeroPoint);
  }

  // Multiply by scale
  result = arith::MulFOp::create(builder, loc, result, scale);
  return result;
}
427
428// Convert a scalar or ranked tensor input with the given scale and zero point
429// values.
430//
431// - input
432// Scalar or ranked tensor value.
433//
434// - inputShape
435// If 'input' is a tensor, combination or attributes/values representing its
436// static/dynamic dimensions. If 'input' is a scalar, empty list.
437//
438// - scale
439// Scale as a floating-point scalar value.
440//
441// - zeroPoint
442// Zero point as an integer scalar value.
443//
444// - quantizedType
445// Scalar quantized type of the result ('quant.qcast') or of the input
446// ('quant.dcast').
447//
448Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
449 Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
450 Value zeroPoint, QuantizedType quantizedType) {
451 if (isa<QuantizeCastOp>(op))
452 return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
453 quantizedType);
454 if (isa<DequantizeCastOp>(op))
455 return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
456 quantizedType);
457 llvm_unreachable("unexpected quant op");
458}
459
460// Convert an operation using per-layer quantization with a scalar or ranked
461// tensor input.
462//
463// - op
464// 'quant.dcast' or 'quant.qcast' op.
465//
466// - input
467// Scalar or ranked tensor.
468//
469// - quantizedType
470// Per-layer quantized type.
471//
472Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
473 Value input, UniformQuantizedType quantizedType) {
474 // Create scale and zero point constants
475 auto expressedType = quantizedType.getExpressedType();
476 auto storageType = quantizedType.getStorageType();
477 auto scaleAttr =
478 builder.getFloatAttr(expressedType, quantizedType.getScale());
479 auto scale =
480 arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
481 auto zeroPointAttr =
482 builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
483 auto zeroPoint =
484 arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);
485
486 auto inputShape = getScalarOrTensorShape(builder, loc, input);
487 return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
488 quantizedType);
489}
490
491// Convert an operation using per-layer quantization.
492//
493// - op
494// 'quant.dcast' or 'quant.qcast' op.
495//
496// - input
497// Scalar, ranked tensor, or unranked tensor.
498//
499// - quantizedType
500// Per-layer quantized type.
501//
502Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
503 Value input, UniformQuantizedType quantizedType) {
504 // Flatten input if unranked
505 bool isUnranked = isa<UnrankedTensorType>(input.getType());
506 Value inputShape;
507 if (isUnranked)
508 std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
509
510 // Process ranked tensor
511 auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
512
513 // Restore original shape if unranked
514 if (isUnranked)
515 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
516
517 return result;
518}
519
// Convert an operation using per-channel quantization and a scalar or ranked
// tensor as an input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
// - channelAxis
//   Dimension of 'input' along which the per-channel scales/zero points vary.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  // 1D constants holding one scale and one zero point per channel.
  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  // Result element type is the opposite of the input's: storage type when the
  // input is float (quantizing), expressed type otherwise (dequantizing).
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  // Elementwise linalg.generic; the scale/zero-point operands are indexed
  // with only the channel dimension of the iteration space.
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      // args = {input elem, scale, zero point, init elem}.
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      // Per-element scalar conversion; empty shape list since
                      // the operands here are scalars.
                      auto result =
                          convertRanked(builder, loc, op, input, {}, scale,
                                        zeroPoint, quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}
580
581// Convert an operation using per-channel quantization.
582//
583// - op
584// 'quant.dcast' or 'quant.qcast' op.
585//
586// - input
587// Scalar, ranked tensor, or unranked tensor.
588//
589// - quantizedType
590// Per-channel quantized type.
591//
592Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
593 Value input,
594 UniformQuantizedPerAxisType quantizedType) {
595 // Flatten unranked tensor into a 3D ranked tensor if necessary
596 bool isUnranked = isa<UnrankedTensorType>(input.getType());
597 int64_t channelAxis = quantizedType.getQuantizedDimension();
598 int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
599 Value inputShape;
600 if (isUnranked) {
601 std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
602 builder, loc, input, channelAxis, channelAxisSize);
603 channelAxis = 1;
604 }
605
606 // Work on a ranked tensor
607 auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
608 channelAxis);
609
610 // Restore original tensor shape if unranked
611 if (isUnranked)
612 result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
613
614 return result;
615}
616
// Convert an operation using sub-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor.
//
// - quantizedType
//   Sub-channel quantized type.
//
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  // Multi-dimensional constants holding one scale/zero point per block.
  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  // Result element type is the opposite of the input's: storage type when the
  // input is float (quantizing), expressed type otherwise (dequantizing).
  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  // Index the scale/zero-point tensors with d / blockSize on each quantized
  // dimension; non-quantized dimensions map to the constant 0.
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      quantizedType.getBlockSizeInfo();
  SmallVector<AffineExpr> affineExprs(inputRank,
                                      builder.getAffineConstantExpr(0));
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    affineExprs[quantizedDimension] =
        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
  }
  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
      builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      // args = {input elem, scale, zero point, init elem}.
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      // Per-element scalar conversion; empty shape list since
                      // the operands here are scalars.
                      auto result =
                          convertRanked(builder, loc, op, input, {}, scale,
                                        zeroPoint, quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}
682
683// Convert a quantization operation.
684//
685// - op
686// 'quant.dcast' or 'quant.qcast' op.
687//
688// - input
689// Scalar, ranked tensor, or unranked tensor. The element type matches
690// the storage type (quant.dcast) or expressed type (quant.qcast) of
691// 'quantizedType'.
692//
693// - quantizedType
694// Per-layer or per-channel quantized type.
695//
696Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
697 Value input, Type quantizedType) {
698 if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
699 return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
700
701 if (auto uniformQuantizedPerAxisType =
702 dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
703 return convertPerChannel(builder, loc, op, input,
704 uniformQuantizedPerAxisType);
705
706 if (auto uniformQuantizedSubChannelType =
707 dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
708 return convertSubChannel(builder, loc, op, input,
709 uniformQuantizedSubChannelType);
710
711 llvm_unreachable("unexpected quantized type");
712}
713
714// Lowering pattern for 'quant.dcast'
715struct DequantizeCastOpConversion
716 : public OpConversionPattern<quant::DequantizeCastOp> {
717 using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;
718
719 LogicalResult
720 matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter) const override {
722 auto loc = op.getLoc();
723 auto input = op.getInput();
724 auto quantizedType =
725 cast<QuantizedType>(getScalarType(op.getInput().getType()));
726
727 // Convert quantized input to storage type
728 auto storageScalarOrTensorType =
729 getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
730 input = quant::StorageCastOp::create(rewriter, loc,
731 storageScalarOrTensorType, input);
732
733 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
734
735 rewriter.replaceOp(op, result);
736 return success();
737 }
738};
739
740// Lowering pattern for 'quant.qcast'
741struct QuantizeCastOpConversion
742 : public OpConversionPattern<quant::QuantizeCastOp> {
743 using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;
744
745 LogicalResult
746 matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 auto loc = op.getLoc();
749 auto input = op.getInput();
750 auto quantizedType = getScalarType(op.getResult().getType());
751
752 // Flatten unranked tensor input
753 auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
754
755 // Cast stored value to result quantized value
756 rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
757 op, op.getResult().getType(), result);
758 return success();
759 }
760};
761
762struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
763 void runOnOperation() override {
764 RewritePatternSet patterns(&getContext());
766
767 ConversionTarget target(getContext());
768 target.addLegalOp<quant::StorageCastOp>();
769 target.addIllegalDialect<quant::QuantDialect>();
770 target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
771 shape::ShapeDialect, tensor::TensorDialect>();
772
773 if (failed(applyPartialConversion(getOperation(), target,
774 std::move(patterns))))
775 signalPassFailure();
776 }
777};
778
779} // namespace
780
782 patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
783 patterns.getContext());
784}
785
786} // namespace quant
787} // namespace mlir
return success()
b getContext())
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
Definition ArithOps.cpp:262
Represents per-axis (also known as per-channel quantization).
Definition QuantTypes.h:324
Represents sub-channel (also known as blockwise quantization).
Definition QuantTypes.h:409
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442