//===- LowerQuantOps.cpp - Lower 'quant' dialect ops ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms `quant.dcast` and `quant.qcast` into lower-level ops.
//
//===----------------------------------------------------------------------===//

12 
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
24 
namespace mlir {
namespace quant {

#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"

namespace {

// If 'inputType' is a tensor, return its element type. If it is a scalar,
// return it as is.
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}

// Return the shape of an input value as a list of attributes (static
// dimensions) and values (dynamic dimensions). If 'input' is a scalar, an
// empty list is returned. If 'input' is a tensor, its shape is returned.
SmallVector<OpFoldResult>
getScalarOrTensorShape(OpBuilder &builder, Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}

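// For example, for a value of type 'tensor<4x?xf32>' this returns a static
// 'IntegerAttr' for the first dimension and the 'Value' produced by a
// 'tensor.dim' op for the second.
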
// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// elements of type 'elementType'.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}

// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If 'referenceType' is a
// scalar, 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create tensor splat
  auto tensorConstant =
      builder.create<tensor::SplatOp>(loc, scalar, referenceShape);
  return tensorConstant;
}

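// For example, broadcasting a scalar '%scale' to a 2D dynamically shaped
// reference tensor emits a 'tensor.splat' op along these lines:
//
//   %splat = tensor.splat %scale[%d0, %d1] : tensor<?x?xf32>
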
// Reshape an unranked tensor into a 1D ranked tensor.
//
// - input
//   Unranked tensor.
//
// Return values:
//
// - flatInput
//   1D ranked, dynamically shaped tensor.
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);
  Value inputSize = builder.create<shape::NumElementsOp>(
      loc, builder.getIndexType(), inputShape);

  // Turn input size into 1D tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape =
      builder.create<tensor::FromElementsOp>(loc, flatShapeType, inputSize);

  // Reshape input tensor into 1D
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);
  return std::make_pair(flatInput, inputShape);
}

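// As an illustration, flattening an unranked 'tensor<*xf32>' emits IR roughly
// like:
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
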
// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result tensor corresponds to dimension 'axis' of the input
// tensor.
//
// - input
//   Unranked tensor.
//
// - axis
//   Index of the input dimension around which other input dimensions will be
//   collapsed.
//
// - axisSize
//   Size of input dimension 'axis'.
//
// Return values:
//
// - flatInput
//   3D ranked tensor of shape [?, axisSize, ?].
//
// - inputShape
//   1D extent tensor containing the shape of the original unranked input.
//
std::pair<Value, Value> flattenUnrankedTensorAroundAxis(OpBuilder &builder,
                                                        Location loc,
                                                        Value input,
                                                        int64_t axis,
                                                        int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = builder.create<shape::ShapeOfOp>(loc, shapeType, input);

  // Get shape and sizes on left and right of axis
  auto axisValue = builder.create<arith::ConstantIndexOp>(loc, axis);
  auto axisNextValue = builder.create<arith::ConstantIndexOp>(loc, axis + 1);
  auto shapeLeft =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeLeft);
  auto shapeRight =
      builder
          .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType},
                                    inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      builder.create<shape::NumElementsOp>(loc, indexType, shapeRight);

  // Compute flat input shape as a 3-element 1D tensor
  auto axisSizeValue = builder.create<arith::ConstantIndexOp>(loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = builder.create<tensor::FromElementsOp>(
      loc, flatShapeType, ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape input to 3D tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = builder.create<tensor::ReshapeOp>(loc, flatInputType, input,
                                                     flatInputShape);

  return std::make_pair(flatInput, inputShape);
}

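// As a worked example, an input with runtime shape [5, 2, 3, 4], 'axis' = 1,
// and 'axisSize' = 2 is collapsed into shape [5, 2, 12]: all dimensions to the
// left and to the right of the axis are folded into one dynamic dimension
// each.
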
// Reshape an input tensor into its original unranked shape.
//
// - input
//   Ranked tensor.
//
// - inputShape
//   1D extent tensor.
//
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return builder.create<tensor::ReshapeOp>(loc, unrankedType, input,
                                           inputShape);
}

// Create a tensor constant containing all scales in a per-channel quantized
// type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>
//
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(scales, [&](double scale) -> Attribute {
    return builder.getFloatAttr(expressedType, scale);
  });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, scalesAttr);
}

// Create a tensor constant containing all zero points in a per-channel
// quantized type. Example:
//
//   !quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>
//
// produces
//
//   %cst = arith.constant dense<[10, 20]> : tensor<2xi8>
//
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return builder.create<arith::ConstantOp>(loc, tensorType, zeroPointsAttr);
}

// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
//
// - input
//   Scalar or ranked tensor input. The element type must match the storage
//   type of 'quantizedType'.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
//
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If the quantized type does not narrow down the storage type range, there
  // is nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMin(), storageType);
  auto storageMaxScalar = builder.create<arith::ConstantIntOp>(
      loc, quantizedType.getStorageTypeMax(), storageType);
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = builder.create<arith::MaxSIOp>(loc, input, storageMin);
    input = builder.create<arith::MinSIOp>(loc, input, storageMax);
  } else {
    input = builder.create<arith::MaxUIOp>(loc, input, storageMin);
    input = builder.create<arith::MinUIOp>(loc, input, storageMax);
  }
  return input;
}

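// For example, a type with narrowed storage bounds such as
// '!quant.uniform<i8<-8:7>:f32, 1.0>' produces a signed clamp:
//
//   %0 = arith.maxsi %input, %c-8_i8 : i8
//   %1 = arith.minsi %0, %c7_i8 : i8
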
// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::FPToSIOp>(loc, resultType, input);
  return builder.create<arith::FPToUIOp>(loc, resultType, input);
}

// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return builder.create<arith::SIToFPOp>(loc, resultType, input);
  return builder.create<arith::UIToFPOp>(loc, resultType, input);
}

// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
//
// See function 'convertRanked()' below for a description of the arguments.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale =
      getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = builder.create<arith::DivFOp>(loc, input, scale);

  // Skip unnecessary computations if no zero point is given
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to stored value
    storedValueFloat =
        builder.create<arith::AddFOp>(loc, scaledValue, zeroPoint);
  }

  // Convert stored value to storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp the stored value if the storage type is bounded
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}

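// For a scalar input and the per-layer type '!quant.uniform<i8:f32, 2.0:10>',
// the emitted sequence is roughly:
//
//   %0 = arith.divf %input, %cst_scale : f32
//   %1 = arith.addf %0, %cst_zero_point : f32
//   %2 = arith.fptosi %1 : f32 to i8
//
// followed by the clamping emitted by 'clampScalarOrTensor()' when the
// storage type is bounded.
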
// Dequantize a scalar or ranked tensor input.
//
// See function 'convertRanked()' below for a description of the arguments.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale =
      getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from stored value
    result = builder.create<arith::SubFOp>(loc, result, zeroPoint);
  }

  // Multiply by scale
  result = builder.create<arith::MulFOp>(loc, result, scale);
  return result;
}

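// For the same per-layer type '!quant.uniform<i8:f32, 2.0:10>', dequantizing
// a scalar stored value emits roughly:
//
//   %0 = arith.sitofp %input : i8 to f32
//   %1 = arith.subf %0, %cst_zero_point : f32
//   %2 = arith.mulf %1, %cst_scale : f32
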
// Convert a scalar or ranked tensor input with the given scale and zero point
// values.
//
// - input
//   Scalar or ranked tensor value.
//
// - inputShape
//   If 'input' is a tensor, combination of attributes/values representing its
//   static/dynamic dimensions. If 'input' is a scalar, empty list.
//
// - scale
//   Scale as a floating-point scalar value.
//
// - zeroPoint
//   Zero point as an integer scalar value.
//
// - quantizedType
//   Scalar quantized type of the result ('quant.qcast') or of the input
//   ('quant.dcast').
//
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}

// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create scale and zero point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale = builder.create<arith::ConstantOp>(loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      builder.create<arith::ConstantOp>(loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}

// Convert an operation using per-layer quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-layer quantized type.
//
Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
                      Value input, UniformQuantizedType quantizedType) {
  // Flatten input if unranked
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  Value inputShape;
  if (isUnranked)
    std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);

  // Process ranked tensor
  auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);

  // Restore original shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert an operation using per-channel quantization with a scalar or ranked
// tensor input.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar or ranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
                              Value input,
                              UniformQuantizedPerAxisType quantizedType,
                              int64_t channelAxis) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializePerChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializePerChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = builder.create<tensor::EmptyOp>(loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  auto channelAxisAffineMap = AffineMap::get(
      inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
      channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
  auto result = builder
                    .create<linalg::GenericOp>(
                        loc,
                        init.getType(),                        // resultType
                        ValueRange{input, scales, zeroPoints}, // inputs
                        ValueRange{init},                      // outputs
                        indexingMaps, iteratorTypes,
                        [&](OpBuilder &builder, Location loc, ValueRange args) {
                          assert(args.size() == 4);
                          auto input = args[0];
                          auto scale = args[1];
                          auto zeroPoint = args[2];

                          auto result =
                              convertRanked(builder, loc, op, input, {}, scale,
                                            zeroPoint, quantizedType);

                          builder.create<linalg::YieldOp>(loc, result);
                        })
                    .getResult(0);

  return result;
}

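// For a 2D input quantized along 'channelAxis' = 1, the 'linalg.generic'
// above broadcasts the 1D scale and zero-point tensors using these indexing
// maps:
//
//   input:       (d0, d1) -> (d0, d1)
//   scales:      (d0, d1) -> (d1)
//   zero points: (d0, d1) -> (d1)
//   output:      (d0, d1) -> (d0, d1)
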
// Convert an operation using per-channel quantization.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor.
//
// - quantizedType
//   Per-channel quantized type.
//
Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedPerAxisType quantizedType) {
  // Flatten unranked tensor into a 3D ranked tensor if necessary
  bool isUnranked = isa<UnrankedTensorType>(input.getType());
  int64_t channelAxis = quantizedType.getQuantizedDimension();
  int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
  Value inputShape;
  if (isUnranked) {
    std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
        builder, loc, input, channelAxis, channelAxisSize);
    channelAxis = 1;
  }

  // Work on a ranked tensor
  auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
                                        channelAxis);

  // Restore original tensor shape if unranked
  if (isUnranked)
    result = restoreUnrankedTensorShape(builder, loc, result, inputShape);

  return result;
}

// Convert a quantization operation.
//
// - op
//   'quant.dcast' or 'quant.qcast' op.
//
// - input
//   Scalar, ranked tensor, or unranked tensor. The element type matches
//   the storage type (quant.dcast) or expressed type (quant.qcast) of
//   'quantizedType'.
//
// - quantizedType
//   Per-layer or per-channel quantized type.
//
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType = dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  llvm_unreachable("unexpected quantized type");
}

// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert quantized input to storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = rewriter.create<quant::StorageCastOp>(
        loc, storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};

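// As an illustration, the pattern above rewrites
//
//   %1 = quant.dcast %0 : !quant.uniform<i8:f32, 2.0:10> to f32
//
// into a 'quant.scast' of %0 to its i8 storage type followed by the
// arithmetic sequence emitted by 'dequantizeValue()'.
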
// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    // Convert the expressed input into a value of the storage type
    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast stored value to result quantized value
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};

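// Conversely, the pattern above rewrites
//
//   %1 = quant.qcast %0 : f32 to !quant.uniform<i8:f32, 2.0:10>
//
// into the arithmetic sequence emitted by 'quantizeValue()' followed by a
// 'quant.scast' from the i8 stored value to the quantized result type.
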
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}

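// Clients can drive this lowering by running the pass defined above or by
// adding these patterns to their own set via
// 'populateLowerQuantOpsPatterns()'. On the command line, the pass is exposed
// through the name declared in Passes.td, assumed here to be
// 'mlir-opt --lower-quant-ops'.
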
} // namespace quant
} // namespace mlir