TosaToLinalgNamed.cpp (MLIR 22.0.0git)
1 //===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the TOSA dialect to the Linalg named ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
21 #include "mlir/IR/PatternMatch.h"
23 
24 #include <type_traits>
25 
26 using namespace mlir;
27 using namespace mlir::tosa;
28 
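// Pads `input` with `padAttr` according to `pad`, which holds a [low, high]
// pair per input dimension. For example, a rank-4 NHWC input with
// pad = {0, 0, 1, 1, 2, 2, 0, 0} gains one row of padding on each side of the
// height dim and two columns on each side of the width dim, while the batch
// and channel dims stay untouched.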
29 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
30  TypedAttr padAttr, OpBuilder &rewriter) {
31  // Input should be padded only if necessary.
32  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
33  return input;
34 
35  ShapedType inputTy = cast<ShapedType>(input.getType());
36  Type inputETy = inputTy.getElementType();
37  auto inputShape = inputTy.getShape();
38 
39  assert((inputShape.size() * 2) == pad.size());
40 
41  SmallVector<int64_t, 4> paddedShape;
42  SmallVector<OpFoldResult, 8> lowIndices;
43  SmallVector<OpFoldResult, 8> highIndices;
44  for (size_t i : llvm::seq(inputShape.size())) {
45  auto lowPad = pad[i * 2];
46  auto highPad = pad[i * 2 + 1];
47  if (ShapedType::isDynamic(inputShape[i]))
48  paddedShape.push_back(inputShape[i]);
49  else
50  paddedShape.push_back(inputShape[i] + highPad + lowPad);
51  lowIndices.push_back(rewriter.getIndexAttr(lowPad));
52  highIndices.push_back(rewriter.getIndexAttr(highPad));
53  }
54 
55  Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);
56 
57  return tensor::PadOp::create(rewriter, loc,
58  RankedTensorType::get(paddedShape, inputETy),
59  input, lowIndices, highIndices, padValue);
60 }
61 
62 static mlir::Value
63 linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
64  Value conv, Value result,
65  ArrayRef<AffineMap> indexingMaps) {
66  ShapedType resultTy = cast<ShapedType>(conv.getType());
67  return linalg::GenericOp::create(
68  rewriter, loc, resultTy, ValueRange({bias, conv}), result,
69  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
70  [](OpBuilder &builder, Location loc, ValueRange args) {
71  Value biasVal = args[0];
72  Type resType = args[1].getType();
73  if (resType != biasVal.getType()) {
74  biasVal =
75  arith::ExtSIOp::create(builder, loc, resType, biasVal);
76  }
77  Value added =
78  arith::AddIOp::create(builder, loc, biasVal, args[1]);
79  linalg::YieldOp::create(builder, loc, added);
80  })
81  .getResult(0);
82 }
83 
84 // Construct the affine map that a linalg generic would use to broadcast the
85 // source tensor into the shape of the result tensor.
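// For example, broadcasting a rank-1 bias of shape [C] into a rank-4 NHWC
// result yields the map (d0, d1, d2, d3) -> (d3); a single-element bias of
// shape [1] yields the constant map (d0, d1, d2, d3) -> (0).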
86 static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
87  Value result) {
88  ShapedType resultTy = cast<ShapedType>(result.getType());
89  ShapedType sourceTy = cast<ShapedType>(source.getType());
90  const int64_t resultRank = resultTy.getRank();
91  const int64_t sourceRank = sourceTy.getRank();
92 
93  // The source tensor is broadcast to all the outer dimensions of the
94  // result tensor.
95  SmallVector<AffineExpr> sourceDims;
96  // In the case of a rank-one source tensor with a single element, TOSA
97  // specifies that the value be broadcast, meaning we need an edge case for a
98  // constant map.
99  assert(sourceTy.hasStaticShape() &&
100  "Dynamic broadcasting shapes not supported!");
101  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
102  sourceDims.push_back(rewriter.getAffineConstantExpr(0));
103  } else {
104  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
105  auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
106  sourceDims.push_back(expr);
107  }
108  }
109 
110  return AffineMap::get(/*dimCount=*/resultRank,
111  /*symbolCount=*/0, sourceDims, rewriter.getContext());
112 }
113 
114 // Broadcast the source value to all the outer dimensions of the result value.
115 // If required, the element type is expanded using an arith.extsi or arith.extf
116 // operation as appropriate.
117 static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
118  Location loc, Value source,
119  Value result) {
120  ShapedType resultTy = cast<ShapedType>(result.getType());
121  const int64_t resultRank = resultTy.getRank();
122  // Creating maps for the input and output of the broadcast-like generic op.
123  SmallVector<AffineMap, 2> indexingMaps;
124  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
125  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
126 
127  // Build the broadcast-like operation as a linalg.generic.
128  return linalg::GenericOp::create(
129  rewriter, loc, resultTy, ValueRange({source}), result,
130  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
131  [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
132  Value biasVal = args[0];
133  Type resType = args[1].getType();
134  if (resType != biasVal.getType()) {
135  biasVal =
136  resultTy.getElementType().isFloat()
137  ? arith::ExtFOp::create(builder, loc, resType, biasVal)
138  .getResult()
139  : arith::ExtSIOp::create(builder, loc, resType,
140  biasVal)
141  .getResult();
142  }
143  linalg::YieldOp::create(builder, loc, biasVal);
144  })
145  .getResult(0);
146 }
147 
148 static mlir::Value reifyConstantDim(int64_t attr,
149  ImplicitLocOpBuilder &builder) {
150  return arith::ConstantIndexOp::create(builder, attr);
151 }
152 
153 // Calculating the output width/height using the formula:
154 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
155 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
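// For example, with IH = 16, pad_top = pad_bottom = 1, KH = 3, dilation_y = 1
// and stride_y = 2: H = ((16 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1
//                     = (15 / 2) + 1 = 8.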
156 
157 static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
158  int64_t padBeforeAttr,
159  int64_t padAfterAttr, Value kernelDim,
160  int64_t strideAttr,
161  int64_t dilationAttr,
162  OpBuilder &rewriter) {
163  ImplicitLocOpBuilder builder(loc, rewriter);
164  auto one = arith::ConstantOp::create(rewriter, loc,
165  IntegerAttr::get(inputDim.getType(), 1));
166  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
167  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
168  Value padAfter = reifyConstantDim(padAfterAttr, builder);
169  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
170 
171  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
172  Value dilation = reifyConstantDim(dilationAttr, builder);
173  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
174  Value addOne = arith::AddIOp::create(builder, dilated, one);
175 
176  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
177  Value stride = reifyConstantDim(strideAttr, builder);
178  Value divide = arith::DivUIOp::create(builder, subtract, stride);
179  return arith::AddIOp::create(builder, divide, one);
180 }
181 
182 // Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
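// The convolution patterns below pass the spatial dims: for the NHWC
// tosa.conv2d both lists are {1, 2} (H/W of the input and KH/KW of the
// [OC, KH, KW, IC] kernel), while tosa.depthwise_conv2d passes {1, 2} and
// {0, 1} for its [KH, KW, C, M] kernel.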
183 static SmallVector<Value> inferDynamicDimsForConv(
184  Location loc, Value input, Value weight, ShapedType resultTy,
185  ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
186  ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
187  ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
188  ShapedType inputTy = cast<ShapedType>(input.getType());
189  int64_t inputRank = inputTy.getRank();
190 
191  SmallVector<Value> dynDims;
192  dynDims.resize(resultTy.getRank());
193 
194  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
195  int64_t inputDim = inputSizeDims[i];
196  int64_t kernelDim = kernelSizeDims[i];
197  if (resultTy.isDynamicDim(inputDim)) {
198  auto padTop = padAttr[i * 2];
199  auto padBottom = padAttr[i * 2 + 1];
200  auto stride = strideAttr[i];
201  auto dilation = dilationAttr[i];
202  Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
203  Value kernelDynDim =
204  tensor::DimOp::create(rewriter, loc, weight, kernelDim);
205  // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
206  dynDims[inputDim] =
207  getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
208  kernelDynDim, stride, dilation, rewriter);
209  }
210  }
211 
212  // Get the batch/channels dimensions.
213  for (int i = 0; i < inputRank; i++) {
214  if (resultTy.isDynamicDim(i) && !dynDims[i])
215  dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
216  }
217 
218  SmallVector<Value> filteredDims = condenseValues(dynDims);
219  return filteredDims;
220 }
221 
222 // Creates a map to collapse the last dimension of the Depthwise convolution op
223 // due to a shape mismatch
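// For example, with outputRank = 4 the reassociation is [[0], [1], [2], [3, 4]]:
// the trailing (channel, multiplier) dims of the 5-D depthwise result are
// collapsed into the single channel dim of the NHWC output.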
224 static void createDepthwiseConvCollapseMap(
225  int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
226  OpBuilder &rewriter) {
227  reassociationMap.resize(outputRank);
228  for (int i = 0; i < outputRank; i++) {
229  reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
230  }
231  reassociationMap[outputRank - 1].push_back(
232  rewriter.getAffineDimExpr(outputRank));
233 }
234 
235 namespace {
236 
237 template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
238 class ConvConverter : public OpConversionPattern<TosaConvOp> {
239 public:
240  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
241  LogicalResult
242  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
243  ConversionPatternRewriter &rewriter) const final {
244  Location loc = op->getLoc();
245  Value input = op->getOperand(0);
246  Value weight = op->getOperand(1);
247  Value bias = op->getOperand(2);
248 
249  ShapedType inputTy = cast<ShapedType>(input.getType());
250  ShapedType weightTy = cast<ShapedType>(weight.getType());
251  ShapedType biasTy = cast<ShapedType>(bias.getType());
252  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
253 
254  Type inputETy = inputTy.getElementType();
255 
256  DenseI64ArrayAttr padAttr = op.getPadAttr();
257  DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
258  DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
259 
260  Type accETy = op.getAccType();
261  Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
262 
263  // Get and verify zero points.
264  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
265  if (failed(maybeIZp))
266  return rewriter.notifyMatchFailure(
267  op, "input zero point cannot be statically determined");
268 
269  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
270  if (failed(maybeWZp))
271  return rewriter.notifyMatchFailure(
272  op, "weight zero point cannot be statically determined");
273 
274  const int64_t inputZpVal = *maybeIZp;
275  const int64_t weightZpVal = *maybeWZp;
276 
277  if (op.verifyInputZeroPoint(inputZpVal).failed())
278  return rewriter.notifyMatchFailure(
279  op, "input zero point must be zero for non-int8 integer types");
280 
281  if (op.verifyWeightZeroPoint(weightZpVal).failed())
282  return rewriter.notifyMatchFailure(
283  op, "weight zero point must be zero for non-int8 integer types");
284 
285  bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
286 
287  if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
288  return rewriter.notifyMatchFailure(
289  op, "tosa.conv ops require static shapes for weight and bias");
290 
291  if (inputETy.isUnsignedInteger())
292  return rewriter.notifyMatchFailure(
293  op, "tosa.conv ops do not support unsigned integer input");
294 
295  llvm::SmallVector<int64_t> inputSizeDims;
296  llvm::SmallVector<int64_t> kernelSizeDims;
297  for (int i = 1; i < resultTy.getRank() - 1; i++) {
298  inputSizeDims.push_back(i);
299  kernelSizeDims.push_back(i);
300  }
301 
302  SmallVector<Value> filteredDims = inferDynamicDimsForConv(
303  loc, input, weight, resultTy, padAttr.asArrayRef(),
304  strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
305  inputSizeDims, kernelSizeDims, rewriter);
306 
307  auto weightShape = weightTy.getShape();
308 
309  // Apply padding as necessary.
310  TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
311  if (hasZp) {
312  int64_t intMin =
313  APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
314  .getSExtValue();
315  int64_t intMax =
316  APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
317  .getSExtValue();
318 
319  if (inputZpVal < intMin || inputZpVal > intMax)
320  return rewriter.notifyMatchFailure(
321  op, "tosa.conv op quantization has zp outside of input range");
322 
323  zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
324  }
325 
326  llvm::SmallVector<int64_t> pad;
327  pad.resize(2, 0);
328  llvm::append_range(pad, padAttr.asArrayRef());
329  pad.resize(pad.size() + 2, 0);
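  // The TOSA pad attribute covers only the spatial dims; the leading and
  // trailing zeros added here extend it with [low, high] pairs for the batch
  // and channel dims so it matches the per-dimension layout applyPad expects.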
330  input = applyPad(loc, input, pad, zeroAttr, rewriter);
331 
332  if (4 == inputTy.getRank()) {
333  // For 2D convolutions, we need to check if the target convolution op
334  // wants a HWCF kernel layout.
335  bool wantHwcf =
336  hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
337  : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
338  if (wantHwcf) {
339  // Transpose the kernel to match dimension ordering of the linalg
340  // convolution operation.
341  // TODO(suderman): See if this can be efficiently folded - check whether
342  // the input is used anywhere else, if not fold the constant.
343  SmallVector<int32_t> weightPerm;
344  for (int i = 1; i < resultTy.getRank(); i++)
345  weightPerm.push_back(i);
346  weightPerm.push_back(0);
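    // For Conv2D this builds the permutation [1, 2, 3, 0], turning the TOSA
    // [OC, KH, KW, IC] kernel into the [KH, KW, IC, OC] (HWCF) layout expected
    // by linalg.conv_2d_nhwc_hwcf.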
347 
348  SmallVector<int64_t> newWeightShape;
349  for (auto dim : weightPerm)
350  newWeightShape.push_back(weightShape[dim]);
351  auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
352  Type newWeightTy =
353  RankedTensorType::get(newWeightShape, weightTy.getElementType());
354  weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
355  weightPermAttr);
356  }
357  }
358 
359  // For Conv3D, transpose the kernel to match the dimension ordering of the
360  // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it is
361  // better to map directly and then transpose later if desired.
362  if (5 == inputTy.getRank()) {
363  // TODO(suderman): See if this can be efficiently folded - check whether
364  // the input is used anywhere else, if not fold the constant.
365  SmallVector<int32_t> weightPerm;
366  for (int i = 1; i < resultTy.getRank(); i++)
367  weightPerm.push_back(i);
368  weightPerm.push_back(0);
369 
370  SmallVector<int64_t> newWeightShape;
371  for (auto dim : weightPerm)
372  newWeightShape.push_back(weightShape[dim]);
373  auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
374  Type newWeightTy =
375  RankedTensorType::get(newWeightShape, weightTy.getElementType());
376  weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
377  weightPermAttr);
378  }
379 
380  // Extract the attributes for convolution.
381  ArrayRef<int64_t> stride = strideTosaAttr;
382  ArrayRef<int64_t> dilation = dilationTosaAttr;
383 
384  // Create the convolution op.
385  auto strideAttr = rewriter.getI64TensorAttr(stride);
386  auto dilationAttr = rewriter.getI64TensorAttr(dilation);
387 
388  Value biasEmptyTensor = tensor::EmptyOp::create(
389  rewriter, loc, resultTy.getShape(), accETy, filteredDims);
390 
391  Value broadcastBias =
392  linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
393 
394  if (hasZp) {
395  auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
396  auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
397 
398  auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
399  auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
400 
401  Value conv = LinalgConvQOp::create(
402  rewriter, loc, resultTy,
403  ValueRange{input, weight, iZpVal, kZpVal},
404  ValueRange{broadcastBias}, strideAttr, dilationAttr)
405  ->getResult(0);
406 
407  rewriter.replaceOp(op, conv);
408  return success();
409  }
410 
411  Value conv = LinalgConvOp::create(
412  rewriter, loc, accTy, ValueRange{input, weight},
413  ValueRange{broadcastBias}, strideAttr, dilationAttr)
414  ->getResult(0);
415 
416  // We may need to truncate back to the result type if the accumulator was
417  // wider than the result.
418  if (resultTy != accTy)
419  conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);
420 
421  rewriter.replaceOp(op, conv);
422  return success();
423  }
424 };
425 
426 class DepthwiseConvConverter
427  : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
428 public:
429  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
430  LogicalResult
431  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
432  ConversionPatternRewriter &rewriter) const final {
433  Location loc = op->getLoc();
434  Value input = op->getOperand(0);
435  Value weight = op->getOperand(1);
436  Value bias = op->getOperand(2);
437 
438  ShapedType inputTy = cast<ShapedType>(input.getType());
439  ShapedType weightTy = cast<ShapedType>(weight.getType());
440  ShapedType biasTy = cast<ShapedType>(bias.getType());
441  ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
442  int64_t resultRank = resultTy.getRank();
443 
444  Type inputETy = inputTy.getElementType();
445  Type resultETy = resultTy.getElementType();
446 
447  auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
448  auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
449  auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
450 
451  Type accETy = op.getAccType();
452 
453  if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
454  return rewriter.notifyMatchFailure(
455  op, "tosa.depthwise_conv ops require static shapes");
456 
457  // Compute output dynamic dims
458  SmallVector<Value> filteredDims = inferDynamicDimsForConv(
459  loc, input, weight, resultTy, padAttr.asArrayRef(),
460  strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
461  /*inputSizeDims=*/{1, 2},
462  /*kernelSizeDims=*/{0, 1}, rewriter);
463 
464  // Get and verify zero points.
465 
466  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
467  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
468  if (failed(maybeIZp))
469  return rewriter.notifyMatchFailure(
470  op, "input zero point cannot be statically determined");
471  if (failed(maybeWZp))
472  return rewriter.notifyMatchFailure(
473  op, "weight zero point cannot be statically determined");
474 
475  const int64_t inputZpVal = *maybeIZp;
476  const int64_t weightZpVal = *maybeWZp;
477 
478  if (op.verifyInputZeroPoint(inputZpVal).failed())
479  return rewriter.notifyMatchFailure(
480  op, "input zero point must be zero for non-int8 integer types");
481 
482  if (op.verifyWeightZeroPoint(weightZpVal).failed())
483  return rewriter.notifyMatchFailure(
484  op, "weight zero point must be zero for non-int8 integer types");
485 
486  bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
487  auto weightShape = weightTy.getShape();
488  auto resultShape = resultTy.getShape();
489 
490  // Apply padding as necessary.
491  TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
492  if (!hasNullZps) {
493  int64_t intMin =
494  APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
495  .getSExtValue();
496  int64_t intMax =
497  APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
498  .getSExtValue();
499 
500  if (inputZpVal < intMin || inputZpVal > intMax)
501  return rewriter.notifyMatchFailure(
502  op, "tosa.depthwise_conv op quantization has zp outside of input "
503  "range");
504 
505  zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
506  }
507 
508  llvm::SmallVector<int64_t> pad;
509  pad.resize(2, 0);
510  llvm::append_range(pad, padAttr.asArrayRef());
511  pad.resize(pad.size() + 2, 0);
512 
513  input = applyPad(loc, input, pad, zeroAttr, rewriter);
514 
515  // Extract the attributes for convolution.
516  ArrayRef<int64_t> stride = strideTosaAttr;
517  ArrayRef<int64_t> dilation = dilationTosaAttr;
518 
519  // Create the convolution op.
520  auto strideAttr = rewriter.getI64TensorAttr(stride);
521  auto dilationAttr = rewriter.getI64TensorAttr(dilation);
522  ShapedType linalgConvTy =
523  RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
524  weightShape[2], weightShape[3]},
525  accETy);
526 
527  auto resultZeroAttr = rewriter.getZeroAttr(accETy);
528  Value emptyTensor = tensor::EmptyOp::create(
529  rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
530  Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
531  Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
532  ValueRange{emptyTensor})
533  .result();
534 
535  Value biasEmptyTensor = tensor::EmptyOp::create(
536  rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
537 
538  // Broadcast the initial value to the output tensor before convolving.
539  SmallVector<AffineMap, 4> indexingMaps;
540  indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
541  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
542  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
543 
544  if (hasNullZps) {
545  Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
546  rewriter, loc, linalgConvTy, ValueRange{input, weight},
547  ValueRange{zeroTensor}, strideAttr, dilationAttr)
548  .getResult(0);
549 
550  // We may need to truncate back to the result type if the accumulator was
551  // wider than the result.
552  if (accETy != resultETy)
553  conv = tosa::CastOp::create(
554  rewriter, loc,
555  RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
556  resultETy),
557  conv);
558 
559  SmallVector<ReassociationExprs, 4> reassociationMap;
560  createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
561  Value convReshape = tensor::CollapseShapeOp::create(
562  rewriter, loc, resultTy, conv, reassociationMap);
563 
564  Value result =
565  linalg::GenericOp::create(
566  rewriter, loc, resultTy, ValueRange({bias, convReshape}),
567  biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
568  [&](OpBuilder &nestedBuilder, Location nestedLoc,
569  ValueRange args) {
570  Value added;
571  if (llvm::isa<FloatType>(inputETy))
572  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
573  args[1]);
574  else
575  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
576  args[1]);
577  linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
578  })
579  .getResult(0);
580  rewriter.replaceOp(op, result);
581  } else {
582  IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
583  IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
584  auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
585  auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
586  Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
587  rewriter, loc, linalgConvTy,
588  ValueRange{input, weight, iZpVal, kZpVal},
589  ValueRange{zeroTensor}, strideAttr, dilationAttr)
590  .getResult(0);
591  SmallVector<ReassociationExprs, 4> reassociationMap;
592  createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
593  Value convReshape = tensor::CollapseShapeOp::create(
594  rewriter, loc, resultTy, conv, reassociationMap);
595  Value result = linalgIntBroadcastExtSIAdd(
596  rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
597  rewriter.replaceOp(op, result);
598  }
599  return success();
600  }
601 };
602 
603 class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
604 public:
605  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
606  LogicalResult
607  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
608  ConversionPatternRewriter &rewriter) const final {
609  Location loc = op.getLoc();
610 
611  auto outputTy = cast<ShapedType>(op.getType());
612  auto outputElementTy = outputTy.getElementType();
613 
614  SmallVector<Value> dynDims;
615  dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
616 
617  if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
618  dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
619  }
620 
621  if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
622  dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
623  }
624 
625  if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
626  dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
627  }
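  // tosa.matmul computes [N, H, C] x [N, C, W] -> [N, H, W]: the batch and H
  // dims of the output come from operand A and the W dim from operand B,
  // which is why the dynamic sizes above are taken from those operands.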
628 
629  SmallVector<Value> filteredDims = condenseValues(dynDims);
630 
631  auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
632  Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
633  auto emptyTensor =
634  tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
635  outputTy.getElementType(), filteredDims);
636  Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
637  ValueRange{emptyTensor})
638  .result();
639 
640  FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
641  FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
642  if (failed(maybeAZp))
643  return rewriter.notifyMatchFailure(
644  op, "input a zero point cannot be statically determined");
645  if (failed(maybeBZp))
646  return rewriter.notifyMatchFailure(
647  op, "input b zero point cannot be statically determined");
648 
649  const int64_t aZpVal = *maybeAZp;
650  const int64_t bZpVal = *maybeBZp;
651 
652  if (op.verifyAZeroPoint(aZpVal).failed())
653  return rewriter.notifyMatchFailure(
654  op, "input a zero point must be zero for non-int8 integer types");
655 
656  if (op.verifyBZeroPoint(bZpVal).failed())
657  return rewriter.notifyMatchFailure(
658  op, "input b zero point must be zero for non-int8 integer types");
659 
660  if (aZpVal == 0 && bZpVal == 0) {
661  rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
662  op, TypeRange{op.getType()},
663  ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
664  return success();
665  }
666 
667  auto aZp = arith::ConstantOp::create(rewriter, loc,
668  rewriter.getI32IntegerAttr(aZpVal));
669  auto bZp = arith::ConstantOp::create(rewriter, loc,
670  rewriter.getI32IntegerAttr(bZpVal));
671  rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
672  op, TypeRange{op.getType()},
673  ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
674 
675  return success();
676  }
677 };
678 
679 class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
680 public:
681  using OpConversionPattern::OpConversionPattern;
682 
683  // Compute the dynamic output sizes of the maxpool operation.
684  static SmallVector<Value>
685  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
686  ConversionPatternRewriter &rewriter) {
687  TensorType resultTy = op.getType();
688  Location loc = op.getLoc();
689 
690  Value input = adaptor.getInput();
691  ArrayRef<int64_t> kernel = op.getKernel();
692  ArrayRef<int64_t> pad = op.getPad();
693  ArrayRef<int64_t> stride = op.getStride();
694 
695  SmallVector<Value> dynamicDims;
696 
697  // Batch dimension
698  if (resultTy.isDynamicDim(0))
699  dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
700 
701  // Height/width dimensions
702  for (int64_t dim : {1, 2}) {
703  if (!resultTy.isDynamicDim(dim))
704  continue;
705 
706  // Index into the attribute arrays
707  int64_t index = dim - 1;
708 
709  // Input height/width
710  Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
711 
712  // Kernel height/width
713  Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
714 
715  // Output height/width
716  Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
717  pad[index * 2 + 1], khw, stride[index],
718  /*dilationAttr=*/1, rewriter);
719  dynamicDims.push_back(ohw);
720  }
721 
722  // Channel dimension
723  if (resultTy.isDynamicDim(3))
724  dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
725 
726  return dynamicDims;
727  }
728 
729  LogicalResult
730  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
731  ConversionPatternRewriter &rewriter) const final {
732  Location loc = op.getLoc();
733  Value input = adaptor.getInput();
734  ShapedType inputTy = cast<ShapedType>(input.getType());
735 
736  bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
737  ShapedType resultTy =
738  getTypeConverter()->convertType<ShapedType>(op.getType());
739  if (!resultTy)
740  return rewriter.notifyMatchFailure(op, "failed to convert type");
741  Type resultETy = inputTy.getElementType();
742 
743  SmallVector<Value> dynamicDims =
744  computeDynamicOutputSizes(op, adaptor, rewriter);
745 
746  // Determine what the initial value needs to be for the max pool op.
747  TypedAttr initialAttr;
748  if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
749  initialAttr = rewriter.getFloatAttr(
750  resultETy, APFloat::getLargest(
751  cast<FloatType>(resultETy).getFloatSemantics(), true));
752 
753  else if (isUnsigned)
754  initialAttr = rewriter.getIntegerAttr(
755  resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
756  else if (isa<IntegerType>(resultETy))
757  initialAttr = rewriter.getIntegerAttr(
758  resultETy,
759  APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
760 
761  if (!initialAttr)
762  return rewriter.notifyMatchFailure(
763  op, "Unsupported initial value for tosa.maxpool_2d op");
764 
765  // Apply padding as necessary.
766  llvm::SmallVector<int64_t> pad;
767  pad.resize(2, 0);
768  llvm::append_range(pad, op.getPad());
769  pad.resize(pad.size() + 2, 0);
770 
771  Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
772 
773  Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
774 
775  ArrayRef<int64_t> kernel = op.getKernel();
776  ArrayRef<int64_t> stride = op.getStride();
777 
778  Attribute strideAttr = rewriter.getI64VectorAttr(stride);
779  Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
780 
781  // Create the linalg op that performs pooling.
782  Value emptyTensor =
783  tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
784  resultTy.getElementType(), dynamicDims);
785 
786  Value filledEmptyTensor =
787  linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
788  .result();
789 
790  Value fakeWindowDims =
791  tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
792 
793  if (isUnsigned) {
794  rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
795  op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
796  filledEmptyTensor, strideAttr, dilationAttr);
797  return llvm::success();
798  }
799 
800  auto resultOp = linalg::PoolingNhwcMaxOp::create(
801  rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
802  ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
803  dilationAttr);
804 
805  rewriter.setInsertionPointAfter(op);
806  NanPropagationMode nanMode = op.getNanMode();
807  rewriter.replaceOp(op, resultOp);
808 
809  // NaN propagation has no meaning for non floating point types.
810  if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
811  return success();
812 
813  // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
814  // compare and select materialization is required.
815  //
816  // In the case of "IGNORE" we need to insert a compare and select. Since
817  // we've already produced a named op we will just take its body and modify
818  // it to include the appropriate checks. If the current value is NaN the
819  // old value of pool will be taken otherwise we use the result.
820  if (nanMode == NanPropagationMode::IGNORE) {
821  auto genericOp = linalg::GenericOp::create(
822  rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
823  resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
824  resultOp.getIteratorTypesArray(),
825  [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
826  IRMapping map;
827  auto oldBlock = resultOp.getRegion().begin();
828  auto oldArgs = oldBlock->getArguments();
829  auto &oldMaxOp = *resultOp.getBlock()->begin();
830  map.map(oldArgs, blockArgs);
831  auto *newOp = opBuilder.clone(oldMaxOp, map);
832  Value isNaN =
833  arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
834  blockArgs.front(), blockArgs.front());
835  auto selectOp = arith::SelectOp::create(
836  opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
837  linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
838  });
839  rewriter.replaceOp(resultOp, genericOp);
840  }
841 
842  return success();
843  }
844 };
845 
846 class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
847 public:
848  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
849 
850  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
851  PatternRewriter &rewriter) const final {
852  Location loc = op.getLoc();
853  Value input = op.getInput();
854  ShapedType inputTy = cast<ShapedType>(input.getType());
855  Type inElementTy = inputTy.getElementType();
856 
857  ShapedType resultTy = cast<ShapedType>(op.getType());
858  Type resultETy = cast<ShapedType>(op.getType()).getElementType();
859 
860  Type accETy = op.getAccType();
861  ShapedType accTy = resultTy.clone(accETy);
862 
863  auto dynamicDimsOr =
864  checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
865  if (!dynamicDimsOr.has_value())
866  return failure();
867  SmallVector<Value> dynamicDims = *dynamicDimsOr;
868 
869  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
870  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
871  if (failed(maybeIZp))
872  return rewriter.notifyMatchFailure(
873  op, "input zero point could not be statically determined");
874  if (failed(maybeOZp))
875  return rewriter.notifyMatchFailure(
876  op, "output zero point could not be statically determined");
877 
878  const int64_t inputZpVal = *maybeIZp;
879  const int64_t outputZpVal = *maybeOZp;
880 
881  // Apply padding as necessary.
882  llvm::SmallVector<int64_t> pad;
883  pad.resize(2, 0);
884  llvm::append_range(pad, op.getPad());
885  pad.resize(pad.size() + 2, 0);
886  TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
887  // Unsupported element type
888  if (!padAttr)
889  return failure();
890  Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
891 
892  auto initialAttr = rewriter.getZeroAttr(accETy);
893  Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
894 
895  ArrayRef<int64_t> kernel = op.getKernel();
896  ArrayRef<int64_t> stride = op.getStride();
897 
898  Attribute strideAttr = rewriter.getI64VectorAttr(stride);
899  Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
900 
901  // Create the linalg op that performs pooling.
902  Value poolEmptyTensor = tensor::EmptyOp::create(
903  rewriter, loc, accTy.getShape(), accETy, dynamicDims);
904 
905  Value filledEmptyTensor =
906  linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
907  ValueRange{poolEmptyTensor})
908  .result();
909 
910  Value fakeWindowDims =
911  tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
912 
913  // Sum across the pooled region.
914  Value poolingOp = linalg::PoolingNhwcSumOp::create(
915  rewriter, loc, ArrayRef<Type>{accTy},
916  ValueRange{paddedInput, fakeWindowDims},
917  filledEmptyTensor, strideAttr, dilationAttr)
918  .getResult(0);
919 
920  // Normalize the summed value by the number of elements grouped in each
921  // pool.
922  Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
923  Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);
924 
925  auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
926  iH = arith::SubIOp::create(rewriter, loc, iH, one);
927  iW = arith::SubIOp::create(rewriter, loc, iW, one);
928 
929  Value genericEmptyTensor = tensor::EmptyOp::create(
930  rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
931 
932  auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
933  auto genericOp = linalg::GenericOp::create(
934  rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
935  ValueRange{genericEmptyTensor},
936  ArrayRef<AffineMap>({affineMap, affineMap}),
937  getNParallelLoopsAttrs(resultTy.getRank()),
938  [&](OpBuilder &b, Location loc, ValueRange args) {
939  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
940 
941  // Determines what portion of the valid input is covered by the
942  // kernel.
943  auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
944  if (pad == 0)
945  return valid;
946 
947  auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
948  Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
949 
950  Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
951  return arith::AddIOp::create(rewriter, loc, valid, offset)
952  ->getResult(0);
953  };
954 
955  auto coverageFn = [&](int64_t i, Value isize) -> Value {
956  Value strideVal =
957  arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
958  Value val =
959  arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);
960 
961  // Find the position relative to the input tensor's ends.
962  Value left = linalg::IndexOp::create(rewriter, loc, i);
963  Value right = arith::SubIOp::create(rewriter, loc, isize, left);
964  left = arith::MulIOp::create(rewriter, loc, left, strideVal);
965  right = arith::MulIOp::create(rewriter, loc, right, strideVal);
966 
967  // Determine how much padding was included.
968  val = padFn(val, left, pad[i * 2]);
969  val = padFn(val, right, pad[i * 2 + 1]);
970  return arith::MaxSIOp::create(rewriter, loc, one, val);
971  };
972 
973  // Compute the indices from either end.
974  Value kH3 = coverageFn(1, iH);
975  Value kW3 = coverageFn(2, iW);
976 
977  // Compute the total number of elements and normalize.
978  auto count = arith::IndexCastOp::create(
979  rewriter, loc, rewriter.getI32Type(),
980  arith::MulIOp::create(rewriter, loc, kH3, kW3));
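          // count = kH3 * kW3 is the number of input elements actually summed
          // into this output position, with padded positions excluded.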
981 
982  // Divide by the number of summed values. For floats this is just a
983  // division; for quantized values, input normalization (zero-point
984  // subtraction and fixed-point rescaling) has to be applied.
985  Value poolVal = args[0];
986  if (isa<FloatType>(accETy)) {
987  auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
988  poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
989  ->getResult(0);
990  if (accETy.getIntOrFloatBitWidth() >
991  resultETy.getIntOrFloatBitWidth())
992  poolVal =
993  arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
994  } else {
995 
996  // If we have quantization information we need to apply an offset
997  // for the input zp value.
998  if (inputZpVal != 0) {
999  auto inputZp = arith::ConstantOp::create(
1000  rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
1001  Value offset =
1002  arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
1003  poolVal =
1004  arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
1005  }
1006 
1007  // Compute: k = 32 - count_leading_zeros(value - 1)
1008  Value one32 = arith::ConstantOp::create(
1009  rewriter, loc, rewriter.getI32IntegerAttr(1));
1010  Value thirtyTwo32 = arith::ConstantOp::create(
1011  rewriter, loc, rewriter.getI32IntegerAttr(32));
1012 
1013  Value countSubOne =
1014  arith::SubIOp::create(rewriter, loc, count, one32);
1015  Value leadingZeros =
1016  math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
1017  Value k =
1018  arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);
1019 
1020  // Compute: numerator = ((1 << 30) + 1) << k
1021  Value k64 =
1022  arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
1023  Value thirtyShiftPlusOne = arith::ConstantOp::create(
1024  rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
1025  Value numerator =
1026  arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
1027 
1028  // Compute: scale.multiplier = numerator / value;
1029  Value count64 = arith::ExtUIOp::create(
1030  rewriter, loc, rewriter.getI64Type(), count);
1031  Value multiplier =
1032  arith::DivUIOp::create(rewriter, loc, numerator, count64);
1033  multiplier = arith::TruncIOp::create(
1034  rewriter, loc, rewriter.getI32Type(), multiplier);
1035 
1036  // Compute: scale.shift = 30 + k
1037  Value k8 =
1038  arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
1039  Value thirty8 = arith::ConstantOp::create(
1040  rewriter, loc, rewriter.getI8IntegerAttr(30));
1041  Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
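          // Together, multiplier and shift form a fixed-point reciprocal of
          // count: multiplier / 2^shift ~= 1 / count, so the tosa.apply_scale
          // below divides the accumulated sum by the number of summed
          // elements with rounding.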
1042 
1043  auto roundingAttr = RoundingModeAttr::get(
1044  rewriter.getContext(), RoundingMode::SINGLE_ROUND);
1045 
1046  auto scaled = tosa::ApplyScaleOp::create(
1047  rewriter, loc, rewriter.getI32Type(), poolVal,
1048  multiplier, shift, roundingAttr)
1049  .getResult();
1050 
1051  // If we have quantization information we need to apply output
1052  // zeropoint.
1053  if (outputZpVal != 0) {
1054  auto outputZp = arith::ConstantOp::create(
1055  rewriter, loc,
1056  b.getIntegerAttr(scaled.getType(), outputZpVal));
1057  scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
1058  .getResult();
1059  }
1060 
1061  // Apply Clip.
1062  int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
1063 
1064  auto min = arith::ConstantIntOp::create(
1065  rewriter, loc, accETy,
1066  APInt::getSignedMinValue(outBitwidth).getSExtValue());
1067  auto max = arith::ConstantIntOp::create(
1068  rewriter, loc, accETy,
1069  APInt::getSignedMaxValue(outBitwidth).getSExtValue());
1070  auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1071  /*isUnsigned=*/false);
1072 
1073  poolVal = clamp;
1074  // Convert type.
1075  if (resultETy != clamp.getType()) {
1076  poolVal =
1077  arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
1078  }
1079  }
1080 
1081  linalg::YieldOp::create(rewriter, loc, poolVal);
1082  });
1083 
1084  rewriter.replaceOp(op, genericOp.getResult(0));
1085  return success();
1086  }
1087 };
1088 
1089 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1090 public:
1092 
1093  LogicalResult matchAndRewrite(tosa::TransposeOp op,
1094  PatternRewriter &rewriter) const final {
1095  const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
1096 
1097  Location loc = op.getLoc();
1098  // The verifier should have made sure we have a valid TOSA permutation
1099  // tensor. isPermutationVector doesn't actually check the TOSA perms we
1100  // expect.
1101  SmallVector<OpFoldResult> inputSizes =
1102  tensor::getMixedSizes(rewriter, loc, op.getInput1());
1103  auto permutedSizes =
1104  applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
1105 
1106  auto permutedInit =
1107  tensor::EmptyOp::create(rewriter, loc, permutedSizes,
1108  op.getInput1().getType().getElementType());
1109  rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1110  op, op.getInput1(), permutedInit,
1111  llvm::to_vector(llvm::map_range(
1112  constantPerms, [](int32_t v) -> int64_t { return v; })));
1113  return success();
1114  }
1115 };
1116 } // namespace
1117 
1118 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1119  const TypeConverter &converter, RewritePatternSet *patterns,
1120  const TosaToLinalgNamedOptions &options) {
1121  if (options.preferConv2DKernelLayoutHWCF) {
1122  patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1123  linalg::Conv2DNhwcHwcfQOp>>(
1124  patterns->getContext());
1125  } else {
1126  patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1127  linalg::Conv2DNhwcFhwcQOp>>(
1128  patterns->getContext());
1129  }
1130  patterns->add<
1131  // clang-format off
1132  ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1133  DepthwiseConvConverter,
1134  MatMulConverter,
1135  AvgPool2dConverter,
1136  TransposeConverter
1137  >(patterns->getContext());
1138 
1139  patterns->add<
1140  MaxPool2dConverter
1141  >(converter, patterns->getContext());
1142  // clang-format on
1143 }