#define GEN_PASS_DEF_LOWERQUANTOPS
#include "mlir/Dialect/Quant/Transforms/Passes.h.inc"
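
// The helpers and patterns below lower 'quant.qcast' and 'quant.dcast' ops
// into 'arith', 'linalg', 'shape', and 'tensor' ops, leaving 'quant.scast'
// as the only legal quant op (see the conversion target in 'LowerQuantOps'
// at the bottom of this file).
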
// If 'inputType' is a tensor, return its element type. Otherwise, return the
// input type itself, which is already a scalar.
Type getScalarType(Type inputType) {
  if (auto tensorType = dyn_cast<TensorType>(inputType))
    return tensorType.getElementType();
  return inputType;
}
 
// Return the shape of 'input' as a list of attributes (static dimensions) or
// values (dynamic dimensions). If 'input' is a scalar, an empty list is
// returned.
SmallVector<OpFoldResult> getScalarOrTensorShape(OpBuilder &builder,
                                                 Location loc, Value input) {
  if (isa<TensorType>(input.getType()))
    return tensor::getMixedSizes(builder, loc, input);
  return {};
}
 
// If 'referenceType' is a scalar, return 'elementType' as is. If
// 'referenceType' is a tensor, return another tensor with the same shape and
// 'elementType' as its element type.
Type getScalarOrTensorType(Type elementType, Type referenceType) {
  if (auto tensorType = dyn_cast<TensorType>(referenceType))
    return tensorType.clone(elementType);
  return elementType;
}
 
// Return a constant with the given value. If 'referenceType' is a tensor, a
// tensor splat of shape 'referenceShape' is returned. If it is a scalar,
// 'referenceShape' is ignored and a scalar constant is returned.
Value getScalarOrTensorConstant(OpBuilder &builder, Location loc, Value scalar,
                                Type referenceType,
                                ArrayRef<OpFoldResult> referenceShape) {
  // If the result type is a scalar, return the unmodified scalar constant.
  auto tensorType = dyn_cast<TensorType>(referenceType);
  if (!tensorType) {
    assert(referenceShape.empty());
    return scalar;
  }

  // Create a tensor splat of the requested shape.
  auto tensorConstant =
      tensor::SplatOp::create(builder, loc, scalar, referenceShape);
  return tensorConstant;
}
 
// Reshape an unranked tensor into a 1D ranked tensor. Returns the flattened
// input together with a 1D extent tensor containing its original shape.
std::pair<Value, Value> flattenUnrankedTensor(OpBuilder &builder, Location loc,
                                              Value input) {
  // Get unranked input shape and total size
  auto *context = builder.getContext();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);
  Value inputSize = shape::NumElementsOp::create(
      builder, loc, builder.getIndexType(), inputShape);

  // Turn the input size into a 1D tensor
  auto flatShapeType = shape::getExtentTensorType(context, 1);
  auto flatInputShape =
      tensor::FromElementsOp::create(builder, loc, flatShapeType, inputSize);

  // Reshape the input tensor into a 1D, dynamically sized tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType =
      RankedTensorType::get({ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);
  return std::make_pair(flatInput, inputShape);
}
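
// For illustration (hand-written, not from this file): flattening an
// unranked 'tensor<*xf32>' input roughly materializes IR such as
//
//   %shape = shape.shape_of %input : tensor<*xf32> -> tensor<?xindex>
//   %size = shape.num_elements %shape : tensor<?xindex> -> index
//   %flat_shape = tensor.from_elements %size : tensor<1xindex>
//   %flat = tensor.reshape %input(%flat_shape)
//       : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>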
 
// Reshape an unranked tensor into a 3D ranked tensor where the central
// dimension of the result corresponds to dimension 'axis' of the input,
// whose static size must be given in 'axisSize'. Returns the flattened input
// together with a 1D extent tensor containing its original shape.
std::pair<Value, Value>
flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input,
                                int64_t axis, int64_t axisSize) {
  // Get full tensor shape
  auto *context = builder.getContext();
  auto indexType = builder.getIndexType();
  auto shapeType = shape::getExtentTensorType(context);
  auto inputShape = shape::ShapeOfOp::create(builder, loc, shapeType, input);

  // Get sizes of the subshapes on the left and right of 'axis'
  auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis);
  auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1);
  auto shapeLeft =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisValue)
          .getResult(0);
  auto sizeLeft =
      shape::NumElementsOp::create(builder, loc, indexType, shapeLeft);
  auto shapeRight =
      shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType},
                               inputShape, axisNextValue)
          .getResult(1);
  auto sizeRight =
      shape::NumElementsOp::create(builder, loc, indexType, shapeRight);

  // Compute the flat input shape as a 3-element 1D tensor
  auto axisSizeValue = arith::ConstantIndexOp::create(builder, loc, axisSize);
  auto flatShapeType = shape::getExtentTensorType(context, 3);
  auto flatInputShape = tensor::FromElementsOp::create(
      builder, loc, flatShapeType,
      ValueRange{sizeLeft, axisSizeValue, sizeRight});

  // Reshape the input into a 3D tensor
  auto inputType = cast<UnrankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto flatInputType = RankedTensorType::get(
      {ShapedType::kDynamic, axisSize, ShapedType::kDynamic}, elementType);
  auto flatInput = tensor::ReshapeOp::create(builder, loc, flatInputType, input,
                                             flatInputShape);

  return std::make_pair(flatInput, inputShape);
}
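
// For illustration (hand-written): with 'axis' = 1 and 'axisSize' = 4, an
// unranked 'tensor<*xf32>' is reshaped into 'tensor<?x4x?xf32>', where the
// leading (trailing) dynamic dimension collapses all dimensions to the left
// (right) of the channel axis.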
 
// Reshape a ranked tensor back into the unranked shape captured in
// 'inputShape', a 1D extent tensor.
Value restoreUnrankedTensorShape(OpBuilder &builder, Location loc, Value input,
                                 Value inputShape) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elementType = inputType.getElementType();
  auto unrankedType = UnrankedTensorType::get(elementType);
  return tensor::ReshapeOp::create(builder, loc, unrankedType, input,
                                   inputShape);
}
 
// Create a 1D tensor constant containing all scales in a per-channel
// quantized type, expressed in the quantized type's expressed type.
Value materializePerChannelScales(OpBuilder &builder, Location loc,
                                  UniformQuantizedPerAxisType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs =
      llvm::map_to_vector(scales, [&](double scale) -> Attribute {
        return builder.getFloatAttr(expressedType, scale);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)scales.size()}, expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
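
// Example (hand-written): for a per-channel type such as
// '!quant.uniform<i8:f32:1, {2.0:10, 3.0:20}>' this produces
// '%scales = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>'.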
 
// Create a 1D tensor constant containing all zero points in a per-channel
// quantized type, expressed in the quantized type's storage type.
Value materializePerChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedPerAxisType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs =
      llvm::map_to_vector(zeroPoints, [&](int64_t zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get({(int64_t)zeroPoints.size()}, storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
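
// Example (hand-written): for the same per-channel type as above, the zero
// points become '%zps = arith.constant dense<[10, 20]> : tensor<2xi8>'.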
 
// Create a tensor constant containing all scales in a sub-channel quantized
// type, preserving the shape of the scales attribute.
Value materializeSubChannelScales(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto scales = quantizedType.getScales();
  auto expressedType = quantizedType.getExpressedType();
  auto scaleAttrs = llvm::map_to_vector(
      scales.getValues<APFloat>(), [&](APFloat scale) -> Attribute {
        return builder.getFloatAttr(expressedType, scale);
      });
  auto tensorType =
      RankedTensorType::get(scales.getType().getShape(), expressedType);
  auto scalesAttr = DenseElementsAttr::get(tensorType, scaleAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, scalesAttr);
}
 
// Create a tensor constant containing all zero points in a sub-channel
// quantized type, preserving the shape of the zero points attribute.
Value materializeSubChannelZeroPoints(
    OpBuilder &builder, Location loc,
    UniformQuantizedSubChannelType quantizedType) {
  auto zeroPoints = quantizedType.getZeroPoints();
  auto storageType = quantizedType.getStorageType();
  auto zeroPointAttrs = llvm::map_to_vector(
      zeroPoints.getValues<APInt>(), [&](APInt zeroPoint) -> Attribute {
        return builder.getIntegerAttr(storageType, zeroPoint);
      });
  auto tensorType =
      RankedTensorType::get(zeroPoints.getType().getShape(), storageType);
  auto zeroPointsAttr = DenseElementsAttr::get(tensorType, zeroPointAttrs);
  return arith::ConstantOp::create(builder, loc, tensorType, zeroPointsAttr);
}
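
// Example (hand-written, type syntax approximate): for a sub-channel type
// with a 2x2 scales attribute, the scales materialize as
// 'arith.constant dense<[[2.0, 3.0], [4.0, 5.0]]> : tensor<2x2xf32>' and the
// zero points as a matching 'tensor<2x2xi8>' constant, one entry per
// quantization block.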
 
// Clamp the given scalar or tensor input using the storage bounds encoded in
// the given quantized type, if present.
Value clampScalarOrTensor(OpBuilder &builder, Location loc, Value input,
                          ArrayRef<OpFoldResult> inputShape,
                          QuantizedType quantizedType) {
  // If the quantized type does not narrow down the storage type range, there
  // is nothing to do.
  if (!quantizedType.hasStorageTypeBounds())
    return input;

  // Materialize storage bounds
  auto inputType = input.getType();
  auto storageType = quantizedType.getStorageType();
  auto storageMinScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMin());
  auto storageMaxScalar = arith::ConstantIntOp::create(
      builder, loc, storageType, quantizedType.getStorageTypeMax());
  auto storageMin = getScalarOrTensorConstant(builder, loc, storageMinScalar,
                                              inputType, inputShape);
  auto storageMax = getScalarOrTensorConstant(builder, loc, storageMaxScalar,
                                              inputType, inputShape);

  // Clamp
  if (quantizedType.isSigned()) {
    input = arith::MaxSIOp::create(builder, loc, input, storageMin);
    input = arith::MinSIOp::create(builder, loc, input, storageMax);
  } else {
    input = arith::MaxUIOp::create(builder, loc, input, storageMin);
    input = arith::MinUIOp::create(builder, loc, input, storageMax);
  }
  return input;
}
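
// Example (hand-written): for a storage type with narrowed bounds such as
// '!quant.uniform<i8<-8:7>:f32, 1.0>', signed inputs are clamped with
// 'arith.maxsi' against -8 and 'arith.minsi' against 7.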
 
// Emit op 'arith.fptosi' or 'arith.fptoui'.
Value convertFloatToInteger(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return arith::FPToSIOp::create(builder, loc, resultType, input);
  return arith::FPToUIOp::create(builder, loc, resultType, input);
}
 
// Emit op 'arith.sitofp' or 'arith.uitofp'.
Value convertIntegerToFloat(OpBuilder &builder, Location loc, Value input,
                            Type resultType, bool isSigned) {
  if (isSigned)
    return arith::SIToFPOp::create(builder, loc, resultType, input);
  return arith::UIToFPOp::create(builder, loc, resultType, input);
}
 
// Quantize a scalar or ranked tensor value. The stored value is clamped using
// the storage bounds encoded in the given quantized type.
Value quantizeValue(OpBuilder &builder, Location loc, Value input,
                    ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Scale input
  auto scaledValue = arith::DivFOp::create(builder, loc, input, scale);

  // Skip unnecessary computations if no zero point is given
  Value storedValueFloat = scaledValue;
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Add zero point to the scaled value
    storedValueFloat =
        arith::AddFOp::create(builder, loc, scaledValue, zeroPoint);
  }

  // Convert the stored value to the storage type
  auto storageScalarOrTensorType =
      getScalarOrTensorType(quantizedType.getStorageType(), inputType);
  auto storedValueInt = convertFloatToInteger(builder, loc, storedValueFloat,
                                              storageScalarOrTensorType,
                                              quantizedType.isSigned());

  // Clamp the stored value if the storage type is bound
  auto storedValueClamped = clampScalarOrTensor(builder, loc, storedValueInt,
                                                inputShape, quantizedType);
  return storedValueClamped;
}
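
// In other words, for a quantized type with scale 's' and zero point 'z',
// this computes 'q = clamp(int(x / s + z))', with the float-to-int cast
// rounding toward zero as per 'arith.fptosi'/'arith.fptoui'.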
 
// Dequantize a scalar or ranked tensor input value.
Value dequantizeValue(OpBuilder &builder, Location loc, Value input,
                      ArrayRef<OpFoldResult> inputShape, Value scale,
                      Value zeroPoint, QuantizedType quantizedType) {
  // Convert scale to tensor if necessary
  auto inputType = input.getType();
  scale = getScalarOrTensorConstant(builder, loc, scale, inputType, inputShape);

  // Convert stored value to float
  auto result = convertIntegerToFloat(builder, loc, input, scale.getType(),
                                      quantizedType.isSigned());

  // Skip unnecessary computations if no zero point is given
  if (!matchPattern(zeroPoint, m_Zero())) {
    // Convert zero point to tensor if necessary
    zeroPoint = getScalarOrTensorConstant(builder, loc, zeroPoint, inputType,
                                          inputShape);

    // Convert zero point from storage to expressed type
    zeroPoint = convertIntegerToFloat(builder, loc, zeroPoint, scale.getType(),
                                      quantizedType.isSigned());

    // Subtract zero point from the stored value
    result = arith::SubFOp::create(builder, loc, result, zeroPoint);
  }

  // Multiply by scale
  result = arith::MulFOp::create(builder, loc, result, scale);
  return result;
}
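
// In other words, this computes the inverse affine mapping
// 'x = (float(q) - z) * s' for scale 's' and zero point 'z'.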
 
// Convert a scalar or ranked tensor input with the given scale and zero point
// values, dispatching on whether 'op' is a quantize or dequantize cast.
Value convertRanked(OpBuilder &builder, Location loc, Operation *op,
                    Value input, ArrayRef<OpFoldResult> inputShape, Value scale,
                    Value zeroPoint, QuantizedType quantizedType) {
  if (isa<QuantizeCastOp>(op))
    return quantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                         quantizedType);
  if (isa<DequantizeCastOp>(op))
    return dequantizeValue(builder, loc, input, inputShape, scale, zeroPoint,
                           quantizedType);
  llvm_unreachable("unexpected quant op");
}
 
// Convert an operation using per-layer quantization with a scalar or ranked
// tensor input.
Value convertPerLayerRanked(OpBuilder &builder, Location loc, Operation *op,
                            Value input, UniformQuantizedType quantizedType) {
  // Create scale and zero point constants
  auto expressedType = quantizedType.getExpressedType();
  auto storageType = quantizedType.getStorageType();
  auto scaleAttr =
      builder.getFloatAttr(expressedType, quantizedType.getScale());
  auto scale =
      arith::ConstantOp::create(builder, loc, expressedType, scaleAttr);
  auto zeroPointAttr =
      builder.getIntegerAttr(storageType, quantizedType.getZeroPoint());
  auto zeroPoint =
      arith::ConstantOp::create(builder, loc, storageType, zeroPointAttr);

  auto inputShape = getScalarOrTensorShape(builder, loc, input);
  return convertRanked(builder, loc, op, input, inputShape, scale, zeroPoint,
                       quantizedType);
}
  502 Value convertPerLayer(OpBuilder &builder, Location loc, Operation *op,
 
  503                       Value input, UniformQuantizedType quantizedType) {
 
  505   bool isUnranked = isa<UnrankedTensorType>(input.getType());
 
  508     std::tie(input, inputShape) = flattenUnrankedTensor(builder, loc, input);
 
  511   auto result = convertPerLayerRanked(builder, loc, op, input, quantizedType);
 
  515     result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
 
  532 Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op,
 
  534                               UniformQuantizedPerAxisType quantizedType,
 
  535                               int64_t channelAxis) {
 
  536   auto *context = builder.getContext();
 
  538   auto inputType = cast<RankedTensorType>(input.getType());
 
  539   auto inputRank = inputType.getRank();
 
  541   auto scales = materializePerChannelScales(builder, loc, quantizedType);
 
  543       materializePerChannelZeroPoints(builder, loc, quantizedType);
 
  545   auto elementType = isa<FloatType>(inputType.getElementType())
 
  546                          ? quantizedType.getStorageType()
 
  547                          : quantizedType.getExpressedType();
 
  549   Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);
 
  551   SmallVector<utils::IteratorType> iteratorTypes(inputRank,
 
  552                                                  utils::IteratorType::parallel);
 
  554       inputRank, 0, builder.getAffineDimExpr(channelAxis), context);
 
  555   SmallVector<AffineMap> indexingMaps{
 
  556       builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap,
 
  557       channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)};
 
  558   auto result = linalg::GenericOp::create(
 
  561                     ValueRange{input, scales, zeroPoints}, 
 
  563                     indexingMaps, iteratorTypes,
 
  564                     [&](OpBuilder &builder, Location loc, ValueRange args) {
 
  565                       assert(args.size() == 4);
 
  566                       auto input = args[0];
 
  567                       auto scale = args[1];
 
  568                       auto zeroPoint = args[2];
 
  571                           convertRanked(builder, loc, op, input, {}, scale,
 
  572                                         zeroPoint, quantizedType);
 
  574                       linalg::YieldOp::create(builder, loc, result);
 
  592 Value convertPerChannel(OpBuilder &builder, Location loc, Operation *op,
 
  594                         UniformQuantizedPerAxisType quantizedType) {
 
  596   bool isUnranked = isa<UnrankedTensorType>(input.getType());
 
  597   int64_t channelAxis = quantizedType.getQuantizedDimension();
 
  598   int64_t channelAxisSize = (int64_t)quantizedType.getScales().size();
 
  601     std::tie(input, inputShape) = flattenUnrankedTensorAroundAxis(
 
  602         builder, loc, input, channelAxis, channelAxisSize);
 
  607   auto result = convertPerChannelRanked(builder, loc, op, input, quantizedType,
 
  612     result = restoreUnrankedTensorShape(builder, loc, result, inputShape);
 
// Convert an operation using sub-channel quantization with a ranked tensor
// input. Scales and zero points are indexed with block-wise affine maps.
Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op,
                        Value input,
                        UniformQuantizedSubChannelType quantizedType) {
  auto *context = builder.getContext();

  auto inputType = cast<RankedTensorType>(input.getType());
  auto inputRank = inputType.getRank();

  auto scales = materializeSubChannelScales(builder, loc, quantizedType);
  auto zeroPoints =
      materializeSubChannelZeroPoints(builder, loc, quantizedType);

  auto elementType = isa<FloatType>(inputType.getElementType())
                         ? quantizedType.getStorageType()
                         : quantizedType.getExpressedType();
  auto initShape = tensor::getMixedSizes(builder, loc, input);
  Value init = tensor::EmptyOp::create(builder, loc, initShape, elementType);

  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  const SmallVector<std::pair<int32_t, int64_t>> &blockSizeInfo =
      quantizedType.getBlockSizeInfo();
  SmallVector<AffineExpr> affineExprs(inputRank,
                                      builder.getAffineConstantExpr(0));
  for (auto [quantizedDimension, blockSize] : blockSizeInfo) {
    affineExprs[quantizedDimension] =
        builder.getAffineDimExpr(quantizedDimension).floorDiv(blockSize);
  }
  auto affineMap = AffineMap::get(inputRank, 0, affineExprs, context);
  SmallVector<AffineMap> indexingMaps{
      builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap,
      builder.getMultiDimIdentityMap(inputRank)};
  auto result = linalg::GenericOp::create(
                    builder, loc,
                    init.getType(),                        // resultType
                    ValueRange{input, scales, zeroPoints}, // inputs
                    ValueRange{init},                      // outputs
                    indexingMaps, iteratorTypes,
                    [&](OpBuilder &builder, Location loc, ValueRange args) {
                      assert(args.size() == 4);
                      auto input = args[0];
                      auto scale = args[1];
                      auto zeroPoint = args[2];

                      auto result =
                          convertRanked(builder, loc, op, input, {}, scale,
                                        zeroPoint, quantizedType);

                      linalg::YieldOp::create(builder, loc, result);
                    })
                    .getResult(0);

  return result;
}
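
// For illustration (hand-written): with a rank-2 input and block sizes
// {d0: 2, d1: 4}, scales and zero points are indexed with the map
// '(d0, d1) -> (d0 floordiv 2, d1 floordiv 4)', so each 2x4 block of the
// input shares one scale and one zero point. Dimensions without an entry in
// 'blockSizeInfo' keep the constant index 0.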
 
// Convert a quantization operation, dispatching on the kind of quantized
// type.
Value convertQuantized(OpBuilder &builder, Location loc, Operation *op,
                       Value input, Type quantizedType) {
  if (auto uniformQuantizedType =
          dyn_cast<UniformQuantizedType>(quantizedType))
    return convertPerLayer(builder, loc, op, input, uniformQuantizedType);

  if (auto uniformQuantizedPerAxisType =
          dyn_cast<UniformQuantizedPerAxisType>(quantizedType))
    return convertPerChannel(builder, loc, op, input,
                             uniformQuantizedPerAxisType);

  if (auto uniformQuantizedSubChannelType =
          dyn_cast<UniformQuantizedSubChannelType>(quantizedType))
    return convertSubChannel(builder, loc, op, input,
                             uniformQuantizedSubChannelType);

  llvm_unreachable("unexpected quantized type");
}
 
// Lowering pattern for 'quant.dcast'
struct DequantizeCastOpConversion
    : public OpConversionPattern<quant::DequantizeCastOp> {
  using OpConversionPattern<quant::DequantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::DequantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType =
        cast<QuantizedType>(getScalarType(op.getInput().getType()));

    // Convert the quantized input to its storage type
    auto storageScalarOrTensorType =
        getScalarOrTensorType(quantizedType.getStorageType(), input.getType());
    input = quant::StorageCastOp::create(rewriter, loc,
                                         storageScalarOrTensorType, input);

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    rewriter.replaceOp(op, result);
    return success();
  }
};
 
// Lowering pattern for 'quant.qcast'
struct QuantizeCastOpConversion
    : public OpConversionPattern<quant::QuantizeCastOp> {
  using OpConversionPattern<quant::QuantizeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(quant::QuantizeCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto quantizedType = getScalarType(op.getResult().getType());

    auto result = convertQuantized(rewriter, loc, op, input, quantizedType);

    // Cast the stored value to the result quantized type
    rewriter.replaceOpWithNewOp<quant::StorageCastOp>(
        op, op.getResult().getType(), result);
    return success();
  }
};
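
// End-to-end sketch (hand-written, not from this file): a per-layer scalar
// quantize cast such as
//
//   %q = quant.qcast %x : f32 to !quant.uniform<i8:f32, 2.0:10>
//
// lowers roughly to
//
//   %scale = arith.constant 2.000000e+00 : f32
//   %zp = arith.constant 10 : i8
//   %div = arith.divf %x, %scale : f32
//   %zpf = arith.sitofp %zp : i8 to f32
//   %add = arith.addf %div, %zpf : f32
//   %int = arith.fptosi %add : f32 to i8
//   %q = quant.scast %int : i8 to !quant.uniform<i8:f32, 2.0:10>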
 
struct LowerQuantOps : public impl::LowerQuantOpsBase<LowerQuantOps> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateLowerQuantOpsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addLegalOp<quant::StorageCastOp>();
    target.addIllegalDialect<quant::QuantDialect>();
    target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
                           shape::ShapeDialect, tensor::TensorDialect>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

void populateLowerQuantOpsPatterns(RewritePatternSet &patterns) {
  patterns.add<DequantizeCastOpConversion, QuantizeCastOpConversion>(
      patterns.getContext());
}
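
// To exercise the pass from the command line (assuming the standard pass
// registration), something like 'mlir-opt --lower-quant-ops input.mlir'
// should apply the patterns above.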
 