MLIR  21.0.0git
LowerQuantOps.cpp
Go to the documentation of this file.
1 //===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
24 
25 namespace mlir {
26 namespace quant {
27 
28 #define GEN_PASS_DEF_LOWERQUANTOPS
29 #include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
30 
31 namespace {
32 
33 // If 'inputType' is a tensor, return its element type. If it is a scalar,
34 // return it as is.
35 Type getScalarType(Type inputType) {
36  if (auto tensorType = dyn_cast<TensorType>(inputType))
37  return tensorType.getElementType();
38  return inputType;
39 }
40 
41 // Return the shape of an input value as a list of attributes (static
42 // dimensions) and values (dynamic dimensions). If 'input' is a scalar, an empty
43 // list is returned. If 'input' is a tensor, its shape is returned.
44 SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
45  Location loc, Value input) {
46  if (isa<TensorType>(input.getType()))
47  return tensor::getMixedSizes(builder, loc, input);
48  return {};
49 }
50 
51 // If 'referenceType' is a scalar, return 'elementType' as is. If
52 // 'referenceType' is a tensor, return another tensor with the same shape and
53 // elements of type 'elementType'.
54 Type getScalarOrTensorType(Type elementType, Type referenceType) {
55  if (auto tensorType = dyn_cast<TensorType>(referenceType))
56  return tensorType.clone(elementType);
57  return elementType;
58 }
59 
60 // Return a constant with the given value. If 'referenceType' is a tensor, a
61 // tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
62 // scalar, 'referenceShape' is ignored and a scalar constant is returned.
63 Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
64  Type referenceType,
65  ArrayRef<OpFoldResult> referenceShape) {
66  // If the result type is a scalar, return the unmodified scalar constant.
67  auto tensorType = dyn_cast<TensorType>(referenceType);
68  if (!tensorType) {
69  assert(referenceShape.empty());
70  return scalar;
71  }
72 
73  // Create tensor splat
74  auto tensorConstant =
75  builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
76  return tensorConstant;
77 }
78 
79 // Reshape an unranked tensor into a 1D ranked tensor.
80 //
81 // - input
82 // Unranked tensor.
83 //
84 // Return values:
85 //
86 // - flatInput
87 // 1D ranked, dynamically shaped tensor.
88 //
89 // - inputShape
90 // 1D extent tensor containing the shape of the original unranked input.
91 //
92 std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
93  Value input) {
94  // Get unranked input shape and total size
95  auto *context = builder.getContext();
96  auto shapeType = shape::getExtentTensorType(context);
97  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
98  Value inputSize = builder.create<shape::NumElementsOp>(
99  loc, builder.getIndexType(), inputShape);
100 
101  // Turn input size into 1D tensor
102  auto flatShapeType = shape::getExtentTensorType(context, 1);
103  auto flatInputShape =
104  builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);
105 
106  // Reshape input tensor into 1D
107  auto inputType = cast<UnrankedTensorType>(input.getType());
108  auto elementType = inputType.getElementType();
109  auto flatInputType =
110  RankedTensorType::get({ShapedType::kDynamic}, elementType);
111  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
112  flatInputShape);
113  return std::make_pair(flatInput, inputShape);
114 }
115 
116 // Reshape an unranked tensor into a 3D ranked tensor where the central
117 // dimension of the result tensor corresponds to dimension 'axis' of the input
118 // tensor.
119 //
120 // - input
121 // Unranked tensor.
122 //
123 // - axis
124 // Index of the input dimension around which other input dimiensions will be
125 // collapsed.
126 //
127 // - axisSize
128 // Size of input dimension 'axis'.
129 //
130 // Return values:
131 //
132 // - flatInput
133 // 3D ranked tensor of shape [?, axisSize, ?].
134 //
135 // - inputShape
136 // 1D extent tensor containing the shape of the original unranked input.
137 //
138 std::pair<Value, Value>
139 flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
140  int64_t axis, int64_t axisSize) {
141  // Get full tensor shape
142  auto *context = builder.getContext();
143  auto indexType = builder.getIndexType();
144  auto shapeType = shape::getExtentTensorType(context);
145  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
146 
147  // Get shape and sizes on left and right of axis
148  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
149  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
150  auto shapeLeft =
151  builder
152  .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
153  inputShape, axisValue)
154  .getResult(0);
155  auto sizeLeft =
156  builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
157  auto shapeRight =
158  builder
159  .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
160  inputShape, axisNextValue)
161  .getResult(1);
162  auto sizeRight =
163  builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);
164 
165  // Compute flat input shape as a 3-element 1D tensor
166  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
167  auto flatShapeType = shape::getExtentTensorType(context, 3);
168  auto flatInputShape = builder.create<tensor::FromElementsOp>(
169  loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});
170 
171  // Reshape input to 3D tensor
172  auto inputType = cast<UnrankedTensorType>(input.getType());
173  auto elementType = inputType.getElementType();
174  auto flatInputType = RankedTensorType::get(
175  {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
176  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
177  flatInputShape);
178 
179  return std::make_pair(flatInput, inputShape);
180 }
181 
182 // Reshape an input tensor into its original unranked shape.
183 //
184 // - input
185 // Ranked tensor.
186 //
187 // - inputShape
188 // 1D extent tensor.
189 //
190 Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
191  Value inputShape) {
192  auto inputType = cast<RankedTensorType>(input.getType());
193  auto elementType = inputType.getElementType();
194  auto unrankedType = UnrankedTensorType::get(elementType);
195  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
196  inputShape);
197 }
198 
199 // Create a tensor constant containing all scales in a per-channel quantized
200 // type. Example:
201 //
202 // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
203 //
204 // produces
205 //
206 // %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
207 //
208 Value materializePerChannelScales(OpBuilder &builder, Location loc,
209  UniformQuantizedPerAxisType quantizedType) {
210  auto scales = quantizedType.getScales();
211  auto expressedType = quantizedType.getExpressedType();
212  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
213  return builder.getFloatAttr(expressedType, scale);
214  });
215  auto tensorType =
216  RankedTensorType::get({(int64_t)scales.size()}, expressedType);
217  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
218  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
219 }
220 
221 // Create a tensor constant containing all zero points in a per-channel
222 // quantized type. Example:
223 //
224 // !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
225 //
226 // produces
227 //
228 // %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
229 //
230 Value materializePerChannelZeroPoints(
231  OpBuilder &builder, Location loc,
232  UniformQuantizedPerAxisType quantizedType) {
233  auto zeroPoints = quantizedType.getZeroPoints();
234  auto storageType = quantizedType.getStorageType();
235  auto zeroPointAttrs =
236  llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
237  return builder.getIntegerAttr(storageType, zeroPoint);
238  });
239  auto tensorType =
240  RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
241  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
242  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
243 }
244 
245 // Create a tensor constant containing all scales in a sub-channel quantized
246 // type. Example:
247 //
248 // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
249 //
250 // produces
251 //
252 // %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
253 //
254 Value materializeSubChannelScales(
255  OpBuilder &builder, Location loc,
256  UniformQuantizedSubChannelType quantizedType) {
257  auto scales = quantizedType.getScales();
258  auto expressedType = quantizedType.getExpressedType();
259  auto scaleAttrs = llvm::map_to_vector(
260  scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
261  return builder.getFloatAttr(expressedType, scale);
262  });
263  auto tensorType =
264  RankedTensorType::get(scales.getType().getShape(), expressedType);
265  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
266  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
267 }
268 
269 // Create a tensor constant containing all zero points in a sub-channel
270 // quantized type. Example:
271 //
272 // !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
273 //
274 // produces
275 //
276 // %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
277 //
278 Value materializeSubChannelZeroPoints(
279  OpBuilder &builder, Location loc,
280  UniformQuantizedSubChannelType quantizedType) {
281  auto zeroPoints = quantizedType.getZeroPoints();
282  auto storageType = quantizedType.getStorageType();
283  auto zeroPointAttrs = llvm::map_to_vector(
284  zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
285  return builder.getIntegerAttr(storageType, zeroPoint);
286  });
287  auto tensorType =
288  RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
289  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
290  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
291 }
292 
293 // Clamp the given scalar or tensor input using the storage bounds encoded in
294 // the given quantized type, if present.
295 //
296 // - input
297 // Scalar or ranked tensor input. The element type must match the storage type
298 // of 'quantizedType'.
299 //
300 // - inputShape
301 // If 'input' is a tensor, combination of attributes/values representing its
302 // static/dynamic dimensions. If 'input' is a scalar, empty list.
303 //
304 // - quantizedType
305 // Per-axis or per-channel quantized type.
306 Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
307  ArrayRef<OpFoldResult> inputShape,
308  QuantizedType quantizedType) {
309  // If quantized type does not narrow down the storage type range, there is
310  // nothing to do.
311  if (!quantizedType.hasStorageTypeBounds())
312  return input;
313 
314  // Materialize bounds
315  auto inputType = input.getType();
316  auto storageType = quantizedType.getStorageType();
317  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
318  loc, quantizedType.getStorageTypeMin(), storageType);
319  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
320  loc, quantizedType.getStorageTypeMax(), storageType);
321  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
322  inputType, inputShape);
323  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
324  inputType, inputShape);
325 
326  // Clamp
327  if (quantizedType.isSigned()) {
328  input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
329  input = builder.create<arith::MinSIOp>(loc, input, storageMax);
330  } else {
331  input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
332  input = builder.create<arith::MinUIOp>(loc, input, storageMax);
333  }
334  return input;
335 }
336 
337 // Emit op 'arith.fptosi' or 'arith.fptoui'.
338 Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
339  Type resultType, bool isSigned) {
340  if (isSigned)
341  return builder.create<arith::FPToSIOp>(loc, resultType, input);
342  return builder.create<arith::FPToUIOp>(loc, resultType, input);
343 }
344 
345 // Emit op 'arith.sitofp' or 'arith.uitofp'.
346 Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
347  Type resultType, bool isSigned) {
348  if (isSigned)
349  return builder.create<arith::SIToFPOp>(loc, resultType, input);
350  return builder.create<arith::UIToFPOp>(loc, resultType, input);
351 }
352 
353 // Quantize a scalar or ranked tensor value. The stored value is clamped using
354 // the storage bounds encoded in the given quantized type.
355 //
356 // See function 'convertRanked()' below for a description of the arguments.
357 Value quantizeValue(OpBuilder &builder, Location loc, Value input,
358  ArrayRef<OpFoldResult> inputShape, Value scale,
359  Value zeroPoint, QuantizedType quantizedType) {
360  // Convert scale to tensor if necessary
361  auto inputType = input.getType();
362  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
363 
364  // Scale input
365  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);
366 
367  // Skip unnecessary computations if no zero point is given
368  Value storedValueFloat = scaledValue;
369  if (!matchPattern(zeroPoint, m_Zero())) {
370  // Convert zero point to tensor if necessary
371  zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
372  inputShape);
373 
374  // Convert zero point from storage to expressed type
375  zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
376  quantizedType.isSigned());
377 
378  // Add zero point to stored value
379  storedValueFloat =
380  builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
381  }
382 
383  // Convert stored value to storage type
384  auto storageScalarOrTensorType =
385  getScalarOrTensorType(quantizedType.getStorageType(), inputType);
386  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
387  storageScalarOrTensorType,
388  quantizedType.isSigned());
389 
390  // Clamp stored value it if the storage type is bound
391  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
392  inputShape, quantizedType);
393  return storedValueClamped;
394 }
395 
396 // Dequantize a scalar or ranked tensor input.
397 //
398 // See function 'convertRanked()' below for a description of the arguments.
399 Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
400  ArrayRef<OpFoldResult> inputShape, Value scale,
401  Value zeroPoint, QuantizedType quantizedType) {
402  // Convert scale to tensor if necessary
403  auto inputType = input.getType();
404  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);
405 
406  // Convert stored value to float
407  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
408  quantizedType.isSigned());
409 
410  // Skip unnecessary computations if no zero point is given
411  if (!matchPattern(zeroPoint, m_Zero())) {
412  // Convert zero point to tensor if necessary
413  zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
414  inputShape);
415 
416  // Convert zero point from storage to expressed type
417  zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
418  quantizedType.isSigned());
419 
420  // Subtract zero point to stored value
421  result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
422  }
423 
424  // Multiply by scale
425  result = builder.create<arith::MulFOp>(loc, result, scale);
426  return result;
427 }
428 
429 // Convert a scalar or ranked tensor input with the given scale and zero point
430 // values.
431 //
432 // - input
433 // Scalar or ranked tensor value.
434 //
435 // - inputShape
436 // If 'input' is a tensor, combination or attributes/values representing its
437 // static/dynamic dimensions. If 'input' is a scalar, empty list.
438 //
439 // - scale
440 // Scale as a floating-point scalar value.
441 //
442 // - zeroPoint
443 // Zero point as an integer scalar value.
444 //
445 // - quantizedType
446 // Scalar quantized type of the result ('quant.qcast') or of the input
447 // ('quant.dcast').
448 //
449 Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
450  Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
451  Value zeroPoint, QuantizedType quantizedType) {
452  if (isa<QuantizeCastOp>(op))
453  return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
454  quantizedType);
455  if (isa<DequantizeCastOp>(op))
456  return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
457  quantizedType);
458  llvm_unreachable("unexpected quant op");
459 }
460 
461 // Convert an operation using per-layer quantization with a scalar or ranked
462 // tensor input.
463 //
464 // - op
465 // 'quant.dcast' or 'quant.qcast' op.
466 //
467 // - input
468 // Scalar or ranked tensor.
469 //
470 // - quantizedType
471 // Per-layer quantized type.
472 //
473 Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
474  Value input, UniformQuantizedType quantizedType) {
475  // Create scale and zero point constants
476  auto expressedType = quantizedType.getExpressedType();
477  auto storageType = quantizedType.getStorageType();
478  auto scaleAttr =
479  builder.getFloatAttr(expressedType, quantizedType.getScale());
480  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
481  auto zeroPointAttr =
482  builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
483  auto zeroPoint =
484  builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);
485 
486  auto inputShape = getScalarOrTensorShape(builder, loc, input);
487  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
488  quantizedType);
489 }
490 
491 // Convert an operation using per-layer quantization.
492 //
493 // - op
494 // 'quant.dcast' or 'quant.qcast' op.
495 //
496 // - input
497 // Scalar, ranked tensor, or unranked tensor.
498 //
499 // - quantizedType
500 // Per-layer quantized type.
501 //
502 Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
503  Value input, UniformQuantizedType quantizedType) {
504  // Flatten input if unranked
505  bool isUnranked = isa<UnrankedTensorType>(input.getType());
506  Value inputShape;
507  if (isUnranked)
508  std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
509 
510  // Process ranked tensor
511  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
512 
513  // Restore original shape if unranked
514  if (isUnranked)
515  result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
516 
517  return result;
518 }
519 
520 // Convert an operation using per-channel quantization and a scalar or ranked
521 // tensor as an input.
522 //
523 // - op
524 // 'quant.dcast' or 'quant.qcast' op.
525 //
526 // - input
527 // Scalar or ranked tensor.
528 //
529 // - quantizedType
530 // Per-channel quantized type.
531 //
532 Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
533  Value input,
534  UniformQuantizedPerAxisType quantizedType,
535  int64_t channelAxis) {
536  auto *context = builder.getContext();
537 
538  auto inputType = cast<RankedTensorType>(input.getType());
539  auto inputRank = inputType.getRank();
540 
541  auto scales = materializePerChannelScales(builder, loc, quantizedType);
542  auto zeroPoints =
543  materializePerChannelZeroPoints(builder, loc, quantizedType);
544 
545  auto elementType = isa<FloatType>(inputType.getElementType())
546  ? quantizedType.getStorageType()
547  : quantizedType.getExpressedType();
548  auto initShape = tensor::getMixedSizes(builder, loc, input);
549  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
550 
551  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
552  utils::IteratorType::parallel);
553  auto channelAxisAffineMap = AffineMap::get(
554  inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
555  SmallVector<AffineMap> indexingMaps{
556  builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
557  channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
558  auto result = builder
559  .create<linalg::GenericOp>(
560  loc,
561  init.getType(), // resultType
562  ValueRange{input, scales, zeroPoints}, // inputs
563  ValueRange{init}, // outputs
564  indexingMaps, iteratorTypes,
565  [&](OpBuilder &builder, Location loc, ValueRange args) {
566  assert(args.size() == 4);
567  auto input = args[0];
568  auto scale = args[1];
569  auto zeroPoint = args[2];
570 
571  auto result =
572  convertRanked(builder, loc, op, input, {}, scale,
573  zeroPoint, quantizedType);
574 
575  builder.create<linalg::YieldOp>(loc, result);
576  })
577  .getResult(0);
578 
579  return result;
580 }
581 
582 // Convert an operation using per-channel quantization.
583 //
584 // - op
585 // 'quant.dcast' or 'quant.qcast' op.
586 //
587 // - input
588 // Scalar, ranked tensor, or unranked tensor.
589 //
590 // - quantizedType
591 // Per-channel quantized type.
592 //
593 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
594  Value input,
595  UniformQuantizedPerAxisType quantizedType) {
596  // Flatten unranked tensor into a 3D ranked tensor if necessary
597  bool isUnranked = isa<UnrankedTensorType>(input.getType());
598  int64_t channelAxis = quantizedType.getQuantizedDimension();
599  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
600  Value inputShape;
601  if (isUnranked) {
602  std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
603  builder, loc, input, channelAxis, channelAxisSize);
604  channelAxis = 1;
605  }
606 
607  // Work on a ranked tensor
608  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
609  channelAxis);
610 
611  // Restore original tensor shape if unranked
612  if (isUnranked)
613  result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
614 
615  return result;
616 }
617 
618 // Convert an operation using sub-channel quantization.
619 //
620 // - op
621 // 'quant.dcast' or 'quant.qcast' op.
622 //
623 // - input
624 // Scalar, ranked tensor.
625 //
626 // - quantizedType
627 // Sub-channel quantized type.
628 //
629 Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
630  Value input,
631  UniformQuantizedSubChannelType quantizedType) {
632  auto *context = builder.getContext();
633 
634  auto inputType = cast<RankedTensorType>(input.getType());
635  auto inputRank = inputType.getRank();
636 
637  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
638  auto zeroPoints =
639  materializeSubChannelZeroPoints(builder, loc, quantizedType);
640 
641  auto elementType = isa<FloatType>(inputType.getElementType())
642  ? quantizedType.getStorageType()
643  : quantizedType.getExpressedType();
644  auto initShape = tensor::getMixedSizes(builder, loc, input);
645  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);
646 
647  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
648  utils::IteratorType::parallel);
649  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
650  quantizedType.getBlockSizeInfo();
651  SmallVector<AffineExpr> affineExprs(inputRank,
652  builder.getAffineConstantExpr(0));
653  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
654  affineExprs[quantizedDimension] =
655  builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
656  }
657  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
658  SmallVector<AffineMap> indexingMaps{
659  builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
660  builder.getMultiDimIdentityMap(inputRank)};
661  auto result = builder
662  .create<linalg::GenericOp>(
663  loc,
664  init.getType(), // resultType
665  ValueRange{input, scales, zeroPoints}, // inputs
666  ValueRange{init}, // outputs
667  indexingMaps, iteratorTypes,
668  [&](OpBuilder &builder, Location loc, ValueRange args) {
669  assert(args.size() == 4);
670  auto input = args[0];
671  auto scale = args[1];
672  auto zeroPoint = args[2];
673 
674  auto result =
675  convertRanked(builder, loc, op, input, {}, scale,
676  zeroPoint, quantizedType);
677 
678  builder.create<linalg::YieldOp>(loc, result);
679  })
680  .getResult(0);
681 
682  return result;
683 }
684 
685 // Convert a quantization operation.
686 //
687 // - op
688 // 'quant.dcast' or 'quant.qcast' op.
689 //
690 // - input
691 // Scalar, ranked tensor, or unranked tensor. The element type matches
692 // the storage type (quant.dcast) or expressed type (quant.qcast) of
693 // 'quantizedType'.
694 //
695 // - quantizedType
696 // Per-layer or per-channel quantized type.
697 //
698 Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
699  Value input, Type quantizedType) {
700  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
701  return convertPerLayer(builder, loc, op, input, uniformQuantizedType);
702 
703  if (auto uniformQuantizedPerAxisType =
704  dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
705  return convertPerChannel(builder, loc, op, input,
706  uniformQuantizedPerAxisType);
707 
708  if (auto uniformQuantizedSubChannelType =
709  dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
710  return convertSubChannel(builder, loc, op, input,
711  uniformQuantizedSubChannelType);
712 
713  llvm_unreachable("unexpected quantized type");
714 }
715 
716 // Lowering pattern for 'quant.dcast'
717 struct DequantizeCastOpConversion
718  : public OpConversionPattern<quant::DequantizeCastOp> {
720 
721  LogicalResult
722  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
723  ConversionPatternRewriter &rewriter) const override {
724  auto loc = op.getLoc();
725  auto input = op.getInput();
726  auto quantizedType =
727  cast<QuantizedType>(getScalarType(op.getInput().getType()));
728 
729  // Convert quantized input to storage type
730  auto storageScalarOrTensorType =
731  getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
732  input = rewriter.create<quant::StorageCastOp>(
733  loc, storageScalarOrTensorType, input);
734 
735  auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
736 
737  rewriter.replaceOp(op, result);
738  return success();
739  }
740 };
741 
742 // Lowering pattern for 'quant.qcast'
743 struct QuantizeCastOpConversion
744  : public OpConversionPattern<quant::QuantizeCastOp> {
746 
747  LogicalResult
748  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
749  ConversionPatternRewriter &rewriter) const override {
750  auto loc = op.getLoc();
751  auto input = op.getInput();
752  auto quantizedType = getScalarType(op.getResult().getType());
753 
754  // Flatten unranked tensor input
755  auto result = convertQuantized(rewriter, loc, op, input, quantizedType);
756 
757  // Cast stored value to result quantized value
758  rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
759  op, op.getResult().getType(), result);
760  return success();
761  }
762 };
763 
764 struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
765  void runOnOperation() override {
766  RewritePatternSet patterns(&getContext());
768 
769  ConversionTarget target(getContext());
770  target.addLegalOp<quant::StorageCastOp>();
771  target.addIllegalDialect<quant::QuantDialect>();
772  target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
773  shape::ShapeDialect, tensor::TensorDialect>();
774 
775  if (failed(applyPartialConversion(getOperation(), target,
776  std::move(patterns))))
777  signalPassFailure();
778  }
779 };
780 
781 } // namespace
782 
784  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
785  patterns.getContext());
786 }
787 
788 } // namespace quant
789 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
void populateLowerQuantOpsPatterns(RewritePatternSet &patterns)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition: Shape.cpp:41
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.