//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//
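//
// For example (illustrative; exact op order and SSA names may differ), a
// scalar per-layer cast such as
//
//   %1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.000000e+00:10>
//
// becomes a division by the scale, an addition of the zero point converted
// to f32, a conversion to the i8 storage type with optional clamping, and a
// final 'quant.scast' back to the quantized type.
//
//===----------------------------------------------------------------------===//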

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Quant/Transforms/Passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}

// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an
// empty list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
                                                 Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}

// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create tensor splat
  auto tensorConstant =
      tensor::SplatOp::create(builder, loc, scalar, referenceShape);
  return tensorConstant;
}

// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
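// Example of the emitted ops (illustrative; SSA names and types may differ):
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
//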
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
  Value inputSize = shape::NumElementsOp::create(
      builder, loc, builder.getIndexType(), inputShape);

  // Turn input size into 1D tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape =
      tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);

  // Reshape input tensor into 1D
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
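// Example (illustrative): for axis = 1 and axisSize = 3, the shape is split
// at the axis, the extents on each side are collapsed into single sizes, and
// the input is reshaped into a 3D tensor:
//
//   %left, %rest = "shape.split_at"(%shape, %c1) ...
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<3xindex>) -> tensor<?x3x?xf32>
//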
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);

  // Get shape and sizes on left and right of axis
  auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
  auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
  auto shapeLeft =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
  auto shapeRight =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      shape::NumElementsOp::create(builder, loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor
  auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = tensor::FromElementsOp::create(
      builder, loc, flatShapeType,
      ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to 3D tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);

  return std::make_pair(flatInput, inputShape);
}

// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
                                   inputShape);
}

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}

// Create a tensor constant containing all scales in a sub-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>
//
Value materializeSubChannelScales(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(
      scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
        return builder.getFloatAttr(expressedType, scale);
      });
  auto tensorType =
      RankedTensorType::get(scales.getType().getShape(), expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a sub-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:{0:1,1:2}, {{2.0:10, 3.0:20}, {4.0:30, 5.0:40}}>
//
// produces
//
//   %cst = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi8>
//
Value materializeSubChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs = llvm::map_to_vector(
      zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}

// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Quantized type carrying the storage bounds used for clamping.
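//
// Example of the emitted clamp (illustrative) for a signed storage type with
// narrowed bounds, e.g. i8 restricted to [-100, 100]:
//
//   %min = arith.constant -100 : i8
//   %max = arith.constant 100 : i8
//   %0 = arith.maxsi %input, %min : i8
//   %1 = arith.minsi %0, %max : i8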
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If quantized type does not narrow down the storage type range, there is
  // nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMin());
  auto storageMaxScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMax());
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = arith::MaxSIOp::create(builder, loc, input, storageMin);
    input = arith::MinSIOp::create(builder, loc, input, storageMax);
  } else {
    input = arith::MaxUIOp::create(builder, loc, input, storageMin);
    input = arith::MinUIOp::create(builder, loc, input, storageMax);
  }
  return input;
}

// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return arith::FPToSIOp::create(builder, loc, resultType, input);
  return arith::FPToUIOp::create(builder, loc, resultType, input);
}

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return arith::SIToFPOp::create(builder, loc, resultType, input);
  return arith::UIToFPOp::create(builder, loc, resultType, input);
}

// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
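//
// Example of the emitted ops (illustrative) for an f32 scalar with scale 2.0
// and a nonzero zero point of 10 stored in i8:
//
//   %0 = arith.divf %input, %scale : f32
//   %1 = arith.addf %0, %zero_point_f32 : f32
//   %2 = arith.fptosi %1 : f32 to i8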
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);

  // Skip unnecessary computations if no zero point is given
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp stored value if the storage type is bounded
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}

// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
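//
// Example of the emitted ops (illustrative) for an i8 scalar with scale 2.0
// and a nonzero zero point of 10:
//
//   %0 = arith.sitofp %input : i8 to f32
//   %1 = arith.subf %0, %zero_point_f32 : f32
//   %2 = arith.mulf %1, %scale : f32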
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = arith::SubFOp::create(builder, loc, result, zeroPoint);
  }

  // Multiply by scale
  result = arith::MulFOp::create(builder, loc, result, scale);
  return result;
}

// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create scale and zero point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale =
      arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}

// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Flatten input if unranked
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Process ranked tensor
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore original shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization with a ranked tensor
// input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
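// The conversion is emitted as a 'linalg.generic' op that broadcasts the
// scales and zero points along the channel axis. For example (illustrative),
// for a tensor<?x2xf32> input quantized along axis 1:
//
//   #map = affine_map<(d0, d1) -> (d0, d1)>
//   #map1 = affine_map<(d0, d1) -> (d1)>
//   %0 = linalg.generic
//       {indexing_maps = [#map, #map1, #map1, #map],
//        iterator_types = ["parallel", "parallel"]}
//       ins(%input, %scales, %zero_points
//           : tensor<?x2xf32>, tensor<2xf32>, tensor<2xi8>)
//       outs(%init : tensor<?x2xi8>) { ... }
//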
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      auto result = convertRanked(builder, loc, op, input, {},
                                                  scale, zeroPoint,
                                                  quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}

// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Flatten unranked tensor into a 3D ranked tensor if necessary
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }

  // Work on a ranked tensor
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);

  // Restore original tensor shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using sub-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Ranked tensor.
//
// - quantizedType
//   Sub-channel quantized type.
//
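// Scales and zero points are indexed block-wise via floor divisions of the
// input indices. For example (illustrative), with block sizes {0:2, 1:4} on a
// 2D input, both operands are accessed through a map equivalent to:
//
//   affine_map<(d0, d1) -> (d0 floordiv 2, d1 floordiv 4)>
//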
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      quantizedType.getBlockSizeInfo();
  SmallVector<AffineExpr> affineExprs(inputRank,
                                      builder.getAffineConstantExpr(0));
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    affineExprs[quantizedDimension] =
        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
  }
  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
      builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      auto result = convertRanked(builder, loc, op, input, {},
                                                  scale, zeroPoint,
                                                  quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer, per-channel, or sub-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  if (auto uniformQuantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
    return convertSubChannel(builder, loc, op, input,
                             uniformQuantizedSubChannelType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = quant::StorageCastOp::create(rewriter, loc,
                                         storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    // Convert the expressed input into its quantized stored representation
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
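
// As a usage sketch (assuming the standard pass registration generated from
// Passes.td), the lowering can be exercised from the command line with:
//
//   mlir-opt --lower-quant-ops input.mlir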

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}

} // namespace quant
} // namespace mlir