//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the TOSA dialect to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}

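// Broadcast-add the (optionally sign-extended) bias to an integer convolution
// result: the bias is extended with arith.extsi when its element type is
// narrower than the accumulator element type, then added elementwise.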
static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}

// Construct the affine map that a linalg generic would use to broadcast the
// source tensor into the shape of the result tensor.
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-1 source tensor with a single element, TOSA
  // specifies that the value be broadcast, meaning we need an edge case for a
  // constant map.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank,
                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi or arith.extf
// operation as appropriate.
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Location loc, Value source,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();
  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal =
                  resultTy.getElementType().isFloat()
                      ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
                            .getResult()
                      : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
                            .getResult();
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
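// For example, with IH = 112, pad_top = pad_bottom = 1, KH = 3, dilation_y = 1
// and stride_y = 2: H = ((112 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1 = 56.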

static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the depthwise convolution op
// due to a shape mismatch.
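// For a rank-4 NHWC result the reassociation groups produced here are
// {d0}, {d1}, {d2}, {d3, d4}: the trailing channel-multiplier dimension of the
// (N, H, W, C, M) depthwise result is folded into the channel dimension.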
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    Type accETy = op.getAccType();
    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

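    // The TOSA pad attribute covers only the spatial dimensions, so extend it
    // with zero padding for the leading batch dimension and the trailing
    // channel dimension before handing it to applyPad.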
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
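        // The TOSA kernel is laid out with the output channels first (FHWC for
        // 2-D); the permutation {1, ..., rank-1, 0} built below moves that
        // dimension to the back, producing the HWCF layout expected here.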
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermAttr);
      }
    }

    // For Conv3D, transpose the kernel to match the dimension ordering of the
    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
    // is better to map directly and transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermAttr);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, accTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // We may need to truncate back to the result type if the accumulator was
    // wider than the result.
    if (resultTy != accTy)
      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    // Get and verify zero points.

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
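    // linalg.depthwise_conv_2d_nhwc_hwcm keeps the channel multiplier as a
    // separate trailing dimension, so the intermediate result has shape
    // (N, H, W, C, M) and is collapsed to the TOSA (N, H, W, C * M) shape
    // further below.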
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (hasNullZps) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // We may need to truncate back to the result type if the accumulator was
      // wider than the result.
      if (accETy != resultETy)
        conv = rewriter.create<tosa::CastOp>(
            loc,
            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
                                  resultETy),
            conv);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added;
                    if (llvm::isa<FloatType>(inputETy))
                      added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
                                                                  args[1]);
                    else
                      added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
                                                                  args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

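    // tosa.matmul operates on rank-3 tensors: the batch and M dimensions of
    // the (batch, M, N) result come from operand A, the N dimension from
    // operand B.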
    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");

    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern<tosa::MaxPool2dOp>::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
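    // This value is used both to fill the output tensor and to pad the input:
    // the most negative finite value for floats, zero for unsigned integers,
    // and the signed minimum for signed integers.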
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));

    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
        op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    rewriter.replaceOp(op, resultOp);

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
      return success();

    // "PROPAGATE" mode matches the behaviour of the Linalg named op, so no
    // compare and select materialization is required.
    //
    // In the case of "IGNORE" we need to insert a compare and select. Since
    // we've already produced a named op, we just take its body and modify it
    // to include the appropriate checks: if the current value is NaN, the old
    // value of the pool is taken; otherwise we use the result.
    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
      auto genericOp = rewriter.create<linalg::GenericOp>(
          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN = opBuilder.create<arith::CmpFOp>(
                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
                blockArgs.front());
            auto selectOp = opBuilder.create<arith::SelectOp>(
                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determines what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

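          // coverageFn computes, for spatial dimension i, how many input
          // elements (excluding padding) fall under the kernel window at the
          // current output position; the pooled sum is divided by this count
          // rather than by the full kernel size.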
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a division; for quantized values, input normalization must be
          // applied instead.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

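          // The reciprocal 1/count is realized as a fixed-point multiplier and
          // shift consumed by tosa.apply_scale. For example, count = 9 gives
          // k = 32 - clz(8) = 4, numerator = ((1 << 30) + 1) << 4,
          // multiplier = numerator / 9 = 1908874355 and shift = 30 + 4 = 34,
          // so (x * multiplier) >> shift is approximately x / 9.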
          // Compute: k = 32 - count_leading_zeros(value - 1)
          Value one32 = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(1));
          Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI32IntegerAttr(32));

          Value countSubOne =
              rewriter.create<arith::SubIOp>(loc, count, one32);
          Value leadingZeros =
              rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
          Value k =
              rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

          // Compute: numerator = ((1 << 30) + 1) << k
          Value k64 =
              rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
          Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
          Value numerator =
              rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

          // Compute: scale.multiplier = numerator / value;
          Value count64 = rewriter.create<arith::ExtUIOp>(
              loc, rewriter.getI64Type(), count);
          Value multiplier =
              rewriter.create<arith::DivUIOp>(loc, numerator, count64);
          multiplier = rewriter.create<arith::TruncIOp>(
              loc, rewriter.getI32Type(), multiplier);

          // Compute: scale.shift = 30 + k
          Value k8 =
              rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
          Value thirty8 = rewriter.create<arith::ConstantOp>(
              loc, rewriter.getI8IntegerAttr(30));
          Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

          auto scaled =
              rewriter
                  .create<tosa::ApplyScaleOp>(
                      loc, rewriter.getI32Type(), poolVal, multiplier, shift,
                      rewriter.getStringAttr("SINGLE_ROUND"))
                  .getResult();

          // If we have quantization information we need to apply output
          // zeropoint.
          if (outputZpVal != 0) {
            auto outputZp = rewriter.create<arith::ConstantOp>(
                loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
            scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                         .getResult();
          }

          // Apply Clip.
          int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

          auto min = rewriter.create<arith::ConstantIntOp>(
              loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
              accETy);
          auto max = rewriter.create<arith::ConstantIntOp>(
              loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
              accETy);
          auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                      /*isUnsigned=*/false);

          poolVal = clamp;
          // Convert type.
          if (resultETy != clamp.getType()) {
            poolVal =
                rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
          }
        }

        rewriter.create<linalg::YieldOp>(loc, poolVal);
      });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid TOSA permutation
    // tensor. isPermutationVector doesn't actually check the TOSA perms we
    // expect.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      AvgPool2dConverter,
      TransposeConverter
      >(patterns->getContext());

  patterns->add<
      MaxPool2dConverter
      >(converter, patterns->getContext());
  // clang-format on
}