//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Linalg named ops.
//
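// For example, a floating-point tosa.conv2d is roughly rewritten to a
// tensor.empty for the result, a linalg.generic that broadcasts the bias into
// that result, and a linalg named convolution (e.g. linalg.conv_2d_nhwc_fhwc)
// that accumulates on top of the broadcast bias.
//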
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

#include <numeric>
#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}

static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}

// Construct the affine map that a linalg generic would use to broadcast the
// source tensor into the shape of the result tensor.
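// For example, a rank-1 bias of shape [C] broadcast into a rank-4 NHWC result
// uses the indexing map (d0, d1, d2, d3) -> (d3).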
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-1 source tensor with a single element, TOSA
  // specifies that the value be broadcast, meaning we need an edge case for a
  // constant map.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank,
                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
                                                Location loc, Value source,
                                                Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();
  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}

// Calculate the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
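// For example, IH = 7, pad_top = pad_bottom = 1, KH = 3, dilation_y = 1, and
// stride_y = 2 give H = ((7+1+1-(1*(3-1)+1))/2)+1 = (6/2)+1 = 4.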

static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D.
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a reassociation map to collapse the last dimension of the depthwise
// convolution result due to a shape mismatch with the TOSA output.
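// For example, with outputRank == 4 the reassociation is
// [[d0], [d1], [d2], [d3, d4]], folding the channel-multiplier dimension into
// the channel dimension.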
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
    bool isQuantized = op.getQuantizationInfo().has_value();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

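    // The TOSA pad attribute covers only the spatial dims, so prepend a zero
    // (before, after) pair for the batch dim and append one for the channel
    // dim to give applyPad one pair per input dimension.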
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants an HWCF kernel layout.
      bool wantHwcf =
          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match the dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
        Value weightPermValue =
            rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermValue);
      }
    }

    // For Conv3D, transpose the kernel to match the dimension ordering of the
    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
    // is better to map directly and then transpose later if desired.
    if (5 == inputTy.getRank()) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getI32TensorAttr(weightPerm);
      Value weightPermValue =
          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermValue);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (isQuantized) {
      auto quantizationInfo = *op.getQuantizationInfo();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), resultETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = cast<ShapedType>(op.getType());
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = cast<ShapedType>(weight.getType());
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = rewriter.getI64TensorAttr(permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, broadcastBias)
                         ->getResult(0);

      rewriter.replaceOp(op, matmul);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto outputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, outputZp},
                broadcastBias)
            ->getResult(0);

    rewriter.replaceOp(op, matmul);
    return success();
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    TypedValue<TensorType> input = op.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays
      int64_t index = dim - 1;

      // Input height/width
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);

      // Kernel height/width
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    TypedValue<TensorType> input = op.getInput();
    ShapedType inputTy = input.getType();

    ShapedType resultTy = op.getType();
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));

    if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledEmptyTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type.
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Determine what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values, input normalization must also be
          // applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {

            // If we have quantization information, we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                rewriter.create<arith::TruncIOp>(loc, rewriter.getI8Type(), k);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);
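            // For example, for count == 9: k = 32 - clz(8) = 4 and
            // shift = 30 + 4 = 34, so multiplier ~= (1 << 34) / 9 and the
            // apply_scale below effectively divides the summed value by 9.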

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                poolVal, multiplier, shift,
                                                rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information, we need to apply the
            // output zero point.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply clipping.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    SmallVector<int32_t> constantPerms;
    if (failed(op.getConstantPerms(constantPerms)))
      return failure();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid TOSA permutation
    // tensor. isPermutationVector doesn't actually check the TOSA perms we
    // expect.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter,
      TransposeConverter
  >(patterns->getContext());
  // clang-format on
}