TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Matchers.h"
22 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/DenseMap.h"
27 
28 using namespace mlir;
29 using namespace mlir::tosa;
30 
31 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // Tosa dialect structs and interface includes.
35 //===----------------------------------------------------------------------===//
36 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
37 #include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc"
38 
39 namespace {
40 //===----------------------------------------------------------------------===//
41 // Dialect Function Inliner Interface.
42 //===----------------------------------------------------------------------===//
43 struct TosaInlinerInterface : public DialectInlinerInterface {
44  using DialectInlinerInterface::DialectInlinerInterface;
45 
46  //===--------------------------------------------------------------------===//
47  // Analysis Hooks.
48  //===--------------------------------------------------------------------===//
49 
50  /// All operations can be inlined by default.
51  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
52  BlockAndValueMapping &map) const final {
53  return true;
54  }
55 
56  /// All regions with If and While parent operators can be inlined.
57  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
58  BlockAndValueMapping &map) const final {
59  return (isa<tosa::IfOp>(dest->getParentOp()) ||
60  isa<tosa::WhileOp>(dest->getParentOp()));
61  }
62 };
63 } // namespace
64 
65 //===----------------------------------------------------------------------===//
66 // TOSA control flow support.
67 //===----------------------------------------------------------------------===//
68 
69 /// Returns the while loop body.
70 Region &tosa::WhileOp::getLoopBody() { return body(); }
71 
72 bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) {
73  return !body().isAncestor(value.getParentRegion());
74 }
75 
76 LogicalResult WhileOp::moveOutOfLoop(ArrayRef<mlir::Operation *> ops) {
77  if (ops.empty())
78  return success();
79 
80  Operation *tosaWhileOp = this->getOperation();
81  for (auto *op : ops)
82  op->moveBefore(tosaWhileOp);
83 
84  return success();
85 }
86 
87 //===----------------------------------------------------------------------===//
88 // Tosa dialect initialization.
89 //===----------------------------------------------------------------------===//
90 
91 void TosaDialect::initialize() {
92  addOperations<
93 #define GET_OP_LIST
94 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
95  >();
96  addInterfaces<TosaInlinerInterface>();
97 }
98 
99 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
100  Type type, Location loc) {
101  // Tosa dialect constants only support ElementsAttr, unlike the standard
102  // dialect constant which supports all attributes.
103  if (value.isa<ElementsAttr>())
104  return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
105  return nullptr;
106 }
107 
108 //===----------------------------------------------------------------------===//
109 // Operator Canonicalizers.
110 //===----------------------------------------------------------------------===//
111 
112 struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
113  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
114 
115  LogicalResult matchAndRewrite(tosa::ConcatOp op,
116  PatternRewriter &rewriter) const override {
117  if (op.input1().size() != 1)
118  return failure();
119  if (op.input1().front().getType() != op.getType()) {
120  rewriter
121  .replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
122  op.input1().front())
123  .getResult();
124  return success();
125  }
126 
127  rewriter.replaceOp(op, op.input1().front());
128  return success();
129  }
130 };
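// Illustrative example (not from the upstream tests): a single-input
// tosa.concat is folded away, and a tensor.cast is inserted only when the
// result type differs from the operand type:
//
//   %0 = "tosa.concat"(%arg0) {axis = 0 : i64}
//          : (tensor<2x3xf32>) -> tensor<2x3xf32>
//   // ==> canonicalizes to a direct use of %arg0.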
131 
132 void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
133  MLIRContext *context) {
134  results.insert<ConcatOptimization>(context);
135 }
136 
137 struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {
138  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
139 
140  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
141  PatternRewriter &rewriter) const override {
142  Value input = op.input1();
143  Operation *definingOp = input.getDefiningOp();
144  if (!definingOp)
145  return failure();
146 
147  if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
148  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
149  op, op.getType(), reshapeOp.input1(), op.new_shape());
150  return success();
151  }
152 
153  return failure();
154  }
155 };
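// Illustrative example: back-to-back reshapes collapse into one, keeping only
// the outer new_shape:
//
//   %0 = "tosa.reshape"(%arg0) {new_shape = [2, 6]} : (tensor<3x4xf32>) -> tensor<2x6xf32>
//   %1 = "tosa.reshape"(%0) {new_shape = [12]} : (tensor<2x6xf32>) -> tensor<12xf32>
//   // ==> %1 = "tosa.reshape"(%arg0) {new_shape = [12]} : (tensor<3x4xf32>) -> tensor<12xf32>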
156 
157 struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {
158  using OpRewritePattern<tosa::ReshapeOp>::OpRewritePattern;
159 
160  LogicalResult matchAndRewrite(tosa::ReshapeOp op,
161  PatternRewriter &rewriter) const override {
162  Value input = op.input1();
163  ArrayAttr newShape = op.new_shape();
164 
165  // Check if input is constant
166  DenseElementsAttr inputAttr;
167  if (!matchPattern(input, m_Constant(&inputAttr)))
168  return failure();
169 
170  // Bail out if the constant has more than one use and is not a splat.
171  if (!input.hasOneUse() && !inputAttr.isSplat())
172  return failure();
173 
174  // Grab the new shape
175  SmallVector<int64_t> newShapeValues = llvm::to_vector<6>(
176  llvm::map_range(newShape.getValue(), [](const Attribute &val) {
177  return val.cast<IntegerAttr>().getValue().getSExtValue();
178  }));
179 
180  // Build new const op with correct output shape
181  ShapedType inputShape = input.getType().cast<ShapedType>();
182  DenseElementsAttr outputAttr =
183  inputAttr.reshape(inputShape.clone(newShapeValues));
184  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputAttr.getType(),
185  outputAttr);
186  return success();
187  }
188 };
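// Illustrative example: a reshape of a (single-use or splat) constant becomes
// a constant with the reshaped type, so no reshape remains at runtime:
//
//   %cst = "tosa.const"() {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
//   %0 = "tosa.reshape"(%cst) {new_shape = [2, 2]} : (tensor<4xf32>) -> tensor<2x2xf32>
//   // ==> %0 = "tosa.const"() {value = dense<1.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>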
189 
190 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
191  MLIRContext *context) {
192  results.insert<ReshapeReshapeOptimization>(context);
193  results.insert<ReshapeConstOptimization>(context);
194 }
195 
196 struct ConstantTransposeOptimization
197  : public OpRewritePattern<tosa::TransposeOp> {
198  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
199 
200  LogicalResult matchAndRewrite(tosa::TransposeOp op,
201  PatternRewriter &rewriter) const override {
202  auto outputType = op.getType().cast<ShapedType>();
203  ArrayRef<int64_t> outputShape = outputType.getShape();
204  // TOSA supports quantized types.
205  if (!outputType.getElementType().isIntOrIndexOrFloat())
206  return failure();
207 
208  DenseElementsAttr inputValues;
209  if (!matchPattern(op.input1(), m_Constant(&inputValues)))
210  return failure();
211  // Make sure the input is a constant that has a single user.
212  if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
213  return failure();
214 
215  DenseIntElementsAttr permAttr;
216  if (!matchPattern(op.perms(), m_Constant(&permAttr)))
217  return failure();
218  auto permValues = llvm::to_vector<6>(llvm::map_range(
219  // TOSA allows both 32- and 64-bit integer tensors here.
220  permAttr.getValues<APInt>(),
221  [](const APInt &val) { return val.getZExtValue(); }));
222 
223  auto inputType = op.input1().getType().cast<ShapedType>();
224  ArrayRef<int64_t> inputShape = inputType.getShape();
225  int64_t numElements = inputType.getNumElements();
226 
227  SmallVector<Attribute, 4> outputValues;
228  outputValues.resize(numElements);
229 
230  // Transpose the input constant. Because we don't know its rank in advance,
231  // we need to loop over the range [0, element count) and delinearize the
232  // index.
233  auto attrValues = inputValues.getValues<Attribute>();
234  for (int srcLinearIndex = 0; srcLinearIndex < numElements;
235  ++srcLinearIndex) {
236  SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
237  int totalCount = srcLinearIndex;
238  for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
239  srcIndices[dim] = totalCount % inputShape[dim];
240  totalCount /= inputShape[dim];
241  }
242 
243  SmallVector<uint64_t, 6> dstIndices(outputType.getRank(), 0);
244  for (int dim = outputType.getRank() - 1; dim >= 0; --dim)
245  dstIndices[dim] = srcIndices[permValues[dim]];
246 
247  uint64_t dstLinearIndex = dstIndices.front();
248  for (int dim = 1; dim < outputType.getRank(); ++dim)
249  dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
250 
251  outputValues[dstLinearIndex] = attrValues[srcIndices];
252  }
253 
254  rewriter.replaceOpWithNewOp<tosa::ConstOp>(
255  op, outputType, DenseElementsAttr::get(outputType, outputValues));
256  return success();
257  }
258 };
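// Worked example of the index math above (illustrative): with input shape
// 2x3, perms = [1, 0] and output shape 3x2, source linear index 1
// delinearizes to srcIndices = [0, 1]; dstIndices = [srcIndices[1],
// srcIndices[0]] = [1, 0], which relinearizes to 1 * 2 + 0 = 2 in the output.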
259 
260 struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
261  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
262 
263  LogicalResult matchAndRewrite(tosa::TransposeOp op,
264  PatternRewriter &rewriter) const override {
265  auto perm = op.perms();
266 
267  DenseIntElementsAttr permAttr;
268  if (!matchPattern(perm, m_Constant(&permAttr))) {
269  return failure();
270  }
271 
272  SmallVector<int64_t> permValues = llvm::to_vector<6>(
273  llvm::map_range(permAttr.getValues<APInt>(),
274  [](const APInt &val) { return val.getSExtValue(); }));
275 
276  for (int i = 0, s = permValues.size(); i < s; i++) {
277  if (i != permValues[i]) {
278  return failure();
279  }
280  }
281 
282  rewriter.replaceOp(op, op.input1());
283  return success();
284  }
285 };
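// Illustrative example: a transpose whose permutation is the identity is
// replaced by its input:
//
//   %perms = "tosa.const"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
//   %0 = "tosa.transpose"(%arg0, %perms) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<1x2x3xf32>
//   // ==> uses of %0 are replaced by %arg0.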
286 
287 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
288  MLIRContext *context) {
289  results.insert<ConstantTransposeOptimization>(context);
290  results.insert<NoOpOptimization>(context);
291 }
292 
293 struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
294  using OpRewritePattern<tosa::AddOp>::OpRewritePattern;
295 
296  LogicalResult matchAndRewrite(tosa::AddOp op,
297  PatternRewriter &rewriter) const override {
298  auto input1 = op.input1();
299  auto input2 = op.input2();
300 
301  DenseElementsAttr input1Attr;
302  if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
303  input2.getType() == op.getType()) {
304  if (input1Attr.getType().getElementType().isa<IntegerType>() &&
305  input1Attr.getSplatValue<APInt>().isZero()) {
306  rewriter.replaceOp(op, op.input2());
307  return success();
308  }
309  }
310 
311  DenseElementsAttr input2Attr;
312  if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
313  input1.getType() == op.getType()) {
314  if (input2Attr.getType().getElementType().isa<IntegerType>() &&
315  input2Attr.getSplatValue<APInt>().isZero()) {
316  rewriter.replaceOp(op, op.input1());
317  return success();
318  }
319  }
320 
321  return failure();
322  }
323 };
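// Illustrative example: adding a splat integer zero is an identity, so the
// non-constant operand is forwarded when its type already matches the result:
//
//   %zero = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
//   %0 = "tosa.add"(%arg0, %zero) : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
//   // ==> uses of %0 are replaced by %arg0.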
324 
325 void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
326  MLIRContext *context) {
327  results.insert<AddZeroOptimization>(context);
328 }
329 
330 struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
331  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
332 
333  LogicalResult matchAndRewrite(tosa::MulOp op,
334  PatternRewriter &rewriter) const override {
335  auto input1 = op.input1();
336  auto input2 = op.input2();
337 
338  DenseElementsAttr input1Attr;
339  if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
340  input2.getType() == op.getType()) {
341  if (input1Attr.getType().getElementType().isa<FloatType>() &&
342  input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
343  rewriter.replaceOp(op, op.input2());
344  return success();
345  }
346 
347  if (input1Attr.getType().getElementType().isa<IntegerType>() &&
348  matchPattern(input1, m_One())) {
349  rewriter.replaceOp(op, op.input2());
350  return success();
351  }
352  }
353 
354  DenseElementsAttr input2Attr;
355  if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() &&
356  input1.getType() == op.getType()) {
357  if (input2Attr.getType().getElementType().isa<FloatType>() &&
358  input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
359  rewriter.replaceOp(op, op.input1());
360  return success();
361  }
362 
363  if (input2Attr.getType().getElementType().isa<IntegerType>() &&
364  matchPattern(input2, m_One())) {
365  rewriter.replaceOp(op, op.input1());
366  return success();
367  }
368  }
369 
370  return failure();
371  }
372 };
373 
374 void MulOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
375  MLIRContext *context) {
376  results.insert<MulOneOptimization>(context);
377 }
378 
379 struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
380  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
381 
382  LogicalResult matchAndRewrite(tosa::PadOp op,
383  PatternRewriter &rewriter) const override {
384  if (op.pad_const())
385  return failure();
386 
387  auto input = op.input1();
388  auto padding = op.padding();
389 
390  ShapedType inputTy = input.getType().cast<ShapedType>();
391  Type elementTy = inputTy.getElementType();
392 
393  Attribute constantAttr;
394  if (elementTy.isa<FloatType>())
395  constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
396  else if (elementTy.isa<IntegerType>() && !op.quantization_info())
397  constantAttr = rewriter.getIntegerAttr(elementTy, 0);
398  else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
399  auto value = op.quantization_info().getValue().input_zp().getValue();
400  constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
401  }
402 
403  if (!constantAttr) {
404  return rewriter.notifyMatchFailure(
405  op,
406  "tosa.pad to linalg lowering encountered an unknown element type");
407  }
408 
409  auto denseAttr = DenseElementsAttr::get(
410  RankedTensorType::get({}, elementTy), constantAttr);
411  auto constantVal = rewriter.create<tosa::ConstOp>(
412  op.getLoc(), denseAttr.getType(), denseAttr);
413 
414  rewriter.replaceOpWithNewOp<tosa::PadOp>(
415  op, op.getType(), ValueRange{input, padding, constantVal},
416  op->getAttrs());
417  return success();
418  }
419 };
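// Illustrative example: a tosa.pad without an explicit pad_const operand gets
// a materialized rank-0 constant (0.0 / 0, or the input zero point for
// quantized types) appended as a third operand, e.g.
//   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// followed by rebuilding the pad as "tosa.pad"(%input, %padding, %zero).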
420 
421 void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
422  MLIRContext *context) {
423  results.insert<MaterializePadValue>(context);
424 }
425 
426 struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
427  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
428 
429  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
430  PatternRewriter &rewriter) const override {
431  Value input = op.input();
432  Value output = op.output();
433  ShapedType inputType = input.getType().cast<ShapedType>();
434  ShapedType outputType = output.getType().cast<ShapedType>();
435 
436  if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
437  return failure();
438  }
439 
440  // If the output and input shapes are 1x1, then this is a no op.
441  ArrayRef<int64_t> outputShape = outputType.getShape();
442  if (outputShape[1] != 1 || outputShape[2] != 1) {
443  return failure();
444  }
445 
446  ArrayRef<int64_t> inputShape = inputType.getShape();
447  if (inputShape[1] != 1 || inputShape[2] != 1) {
448  return failure();
449  }
450 
451  rewriter.replaceOp(op, input);
452  return success();
453  }
454 };
455 
456 void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
457  MLIRContext *context) {
458  results.insert<MaxPool2dIsNoOp>(context);
459 }
460 
461 struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
462  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
463 
464  LogicalResult matchAndRewrite(tosa::ClampOp op,
465  PatternRewriter &rewriter) const override {
466  Value input = op.input();
467  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
468  auto inputElementType = inputType.getElementType();
469 
470  if (!inputType.hasStaticShape()) {
471  return failure();
472  }
473 
474  if (inputElementType.isF32()) {
475  auto minClamp = op.min_fp();
476  auto maxClamp = op.max_fp();
477  bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
478  minClamp.isNegative();
479  bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
480  !maxClamp.isNegative();
481 
482  if (isMin && isMax) {
483  rewriter.replaceOp(op, input);
484  return success();
485  }
486  return failure();
487  }
488 
489  if (inputElementType.isUnsignedInteger()) {
490  int64_t minClamp = op.min_int();
491  int64_t maxClamp = op.max_int();
492 
493  int64_t intMin =
494  APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
495  .getZExtValue();
496  int64_t intMax =
497  APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
498  .getZExtValue();
499 
500  if (minClamp <= intMin && maxClamp >= intMax) {
501  rewriter.replaceOp(op, input);
502  return success();
503  }
504  return failure();
505  }
506 
507  if (inputElementType.isa<IntegerType>()) {
508  int64_t minClamp = op.min_int();
509  int64_t maxClamp = op.max_int();
510 
511  int64_t intMin =
512  APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
513  .getSExtValue();
514  int64_t intMax =
515  APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
516  .getSExtValue();
517 
518  if (minClamp <= intMin && maxClamp >= intMax) {
519  rewriter.replaceOp(op, input);
520  return success();
521  }
522  return failure();
523  }
524 
525  return failure();
526  }
527 };
528 
529 struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
530  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
531 
532  LogicalResult matchAndRewrite(tosa::ClampOp op,
533  PatternRewriter &rewriter) const override {
534  Value input = op.input();
535 
536  Operation *definingOp = input.getDefiningOp();
537  if (!definingOp)
538  return failure();
539 
540  if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
541  auto min_fp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
542  auto max_fp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
543 
544  auto min_int = std::max(op.min_int(), clampOp.min_int());
545  auto max_int = std::min(op.max_int(), clampOp.max_int());
546 
547  rewriter.replaceOpWithNewOp<tosa::ClampOp>(
548  op, op.getType(), clampOp.input(),
549  rewriter.getI64IntegerAttr(min_int),
550  rewriter.getI64IntegerAttr(max_int), rewriter.getF32FloatAttr(min_fp),
551  rewriter.getF32FloatAttr(max_fp));
552  return success();
553  }
554 
555  return failure();
556  }
557 };
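// Illustrative example: nested clamps merge by intersecting their ranges,
// e.g. clamp(clamp(x, 0, 10), 2, 20) becomes clamp(x, 2, 10), since the merged
// bounds are the max of the mins and the min of the maxes.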
558 
559 void ClampOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
560  MLIRContext *context) {
561  results.insert<ClampIsNoOp>(context);
562  results.insert<ClampClampOptimization>(context);
563 }
564 
565 //===----------------------------------------------------------------------===//
566 // Operator Folders.
567 //===----------------------------------------------------------------------===//
568 
569 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
570  if (input().getType() == getType())
571  return input();
572  return {};
573 }
574 
575 OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
576  assert(operands.empty() && "constant has no operands");
577  return valueAttr();
578 }
579 
580 #define ReduceFolder(OP) \
581  OpFoldResult OP::fold(ArrayRef<Attribute> operands) { \
582  ShapedType inputTy = input().getType().cast<ShapedType>(); \
583  if (!inputTy.hasRank()) \
584  return {}; \
585  if (inputTy.getDimSize(axis()) == 1) \
586  return input(); \
587  return {}; \
588  }
589 
590 ReduceFolder(ReduceAllOp) ReduceFolder(ReduceAnyOp) ReduceFolder(ReduceMaxOp)
591  ReduceFolder(ReduceMinOp) ReduceFolder(ReduceProdOp)
592  ReduceFolder(ReduceSumOp)
593 #undef ReduceFolder
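// Illustrative example: reducing along an axis whose size is already 1 is an
// identity, so e.g. "tosa.reduce_sum" with axis = 1 on tensor<4x1x8xf32>
// folds to its input (the reduced axis is kept as size 1, so types match).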
594 
595  OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
596  auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
597  auto outputTy = getType().dyn_cast<RankedTensorType>();
598 
599  if (!inputTy || !outputTy || inputTy != outputTy)
600  return {};
601  return input1();
602 }
603 
604 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
605  // If the pad is all zeros we can fold this operation away.
606  if (operands[1]) {
607  auto densePad = operands[1].cast<DenseElementsAttr>();
608  if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
609  return input1();
610  }
611  }
612 
613  return {};
614 }
615 
616 OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
617  auto inputTy = input().getType().dyn_cast<RankedTensorType>();
618  auto outputTy = getType().dyn_cast<RankedTensorType>();
619 
620  if (!inputTy || !outputTy || inputTy != outputTy)
621  return {};
622  if (inputTy.hasStaticShape())
623  return input();
624 
625  return {};
626 }
627 
628 OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
629  bool allOnes = true;
630  for (Attribute val : multiples().getValue()) {
631  allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
632  }
633 
634  if (allOnes && input1().getType() == getType())
635  return input1();
636  return {};
637 }
638 
639 OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
640  if (!operands[1])
641  return {};
642 
643  // Transposing splat values just means reshaping.
644  if (auto input = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
645  if (input.isSplat())
646  return input.reshape(getType().cast<ShapedType>());
647  }
648 
649  auto perms = llvm::to_vector<6>(llvm::map_range(
650  operands[1].cast<DenseIntElementsAttr>().getValues<APInt>(),
651  [](const APInt &val) { return val.getSExtValue(); }));
652 
653  if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
654  input1().getType() == getType())
655  return input1();
656  return {};
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // TOSA Operator Verifiers.
661 //===----------------------------------------------------------------------===//
662 
663 template <typename T>
664 static LogicalResult verifyConvOp(T op) {
665  // All TOSA conv ops have an input() and weight().
666  auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
667  auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
668 
669  // Must be ranked tensor types
670  if (!inputType) {
671  op.emitOpError("expect a ranked tensor for input, got ") << op.input();
672  return failure();
673  }
674  if (!weightType) {
675  op.emitOpError("expect a ranked tensor for weight, got ") << op.weight();
676  return failure();
677  }
678 
679  auto inputEType = inputType.getElementType();
680  auto weightEType = weightType.getElementType();
681 
682  bool inputIsQuant = !inputEType.template isa<FloatType>();
683  bool weightIsQuant = !weightEType.template isa<FloatType>();
684 
685  // Either both must be quantized or both unquantized.
686  if (inputIsQuant != weightIsQuant) {
687  op.emitOpError(
688  "expect both input and weight to be float or not together, got ")
689  << inputEType << " and " << weightEType;
690  return failure();
691  }
692 
693  // A quantized type must have a constructed quantization_info attribute, and
694  // an unquantized type must not have one.
695  if ((inputIsQuant && !op.quantization_info()) ||
696  (!inputIsQuant && op.quantization_info())) {
697  op.emitOpError("quantizationattr is required for quantized type, and not "
698  "allowed for float type");
699  return failure();
700  }
701 
702  return success();
703 }
704 
705 static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) {
706  auto inputETy = op.input().getType().cast<ShapedType>().getElementType();
707  auto resultETy = op.getType().cast<ShapedType>().getElementType();
708 
709  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
710  inputETy = quantType.getStorageType();
711 
712  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
713  resultETy = quantType.getStorageType();
714 
715  if (inputETy.isF32() && resultETy.isF32())
716  return success();
717  if (inputETy.isInteger(8) && resultETy.isInteger(8))
718  return success();
719  if (inputETy.isInteger(16) && resultETy.isInteger(16))
720  return success();
721 
722  return op.emitOpError("input/output element types are incompatible.");
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // TOSA Operator Quantization Builders.
727 //===----------------------------------------------------------------------===//
728 
729 /// This builder is called on all convolution operators except TransposeConv,
730 /// which has specialized output shape semantics. The builder also defines the
731 /// bitwidth of the output given the bit width of the input & weight content.
732 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
733  Type outputType, Value input, Value weight,
734  Value bias, ArrayAttr pad,
735  ArrayAttr stride, ArrayAttr dilation) {
736 
737  result.addOperands({input, weight, bias});
738  result.addAttribute("pad", pad);
739  result.addAttribute("stride", stride);
740  result.addAttribute("dilation", dilation);
741 
742  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
743  if (quantAttr) {
744  result.addAttribute("quantization_info", quantAttr);
745  result.addTypes(
746  buildConvOpResultTypeInfo(builder, outputType, input, weight));
747  } else {
748  result.addTypes(outputType);
749  }
750 }
751 
752 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
753 static void
754 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
755  Type outputType, Value input, Value weight,
756  Value bias, ArrayAttr outpad, ArrayAttr stride,
757  ArrayAttr dilation, ArrayAttr outputShape) {
758  result.addOperands({input, weight, bias});
759  result.addAttribute("out_pad", outpad);
760  result.addAttribute("stride", stride);
761  result.addAttribute("dilation", dilation);
762  result.addAttribute("out_shape", outputShape);
763  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
764 
765  if (quantAttr) {
766  result.addAttribute("quantization_info", quantAttr);
767  result.addTypes(
768  buildConvOpResultTypeInfo(builder, outputType, input, weight));
769  } else {
770  result.addTypes(outputType);
771  }
772 }
773 
774 /// The tosa.fully_connected op has its own builder as it does not have
775 /// strides/dilation/padding.
776 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
777  Type outputType, Value input, Value weight,
778  Value bias) {
779 
780  result.addOperands({input, weight, bias});
781  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
782  if (quantAttr) {
783  result.addAttribute("quantization_info", quantAttr);
784  result.addTypes(
785  buildConvOpResultTypeInfo(builder, outputType, input, weight));
786  } else {
787  result.addTypes(outputType);
788  }
789 }
790 
791 /// The tosa.matmul op is also intended to be generated where a fully_connected
792 /// op would be constructed but the weight is not a constant. In this case,
793 /// the fully_connected op must be expressed using matmul.
794 /// TODO: Add link to the legalization document explaining this.
795 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
796  OperationState &result, Type outputType,
797  Value a, Value b) {
798  result.addOperands({a, b});
799  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
800 
801  if (quantAttr) {
802  result.addAttribute("quantization_info", quantAttr);
803 
804  auto inputType = a.getType().dyn_cast<ShapedType>();
805  assert(inputType && "Input must be a shaped tensor type!");
806 
807  auto inputQType = inputType.getElementType()
808  .dyn_cast<mlir::quant::UniformQuantizedType>();
809  assert(inputQType && "Tensor must have quantized datatype!");
810 
811  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
812 
813  auto outputShapedType = outputType.dyn_cast<ShapedType>();
814  assert(outputShapedType && "Output must be a shaped type");
815 
816  IntegerType accElementType;
817  if (inputBits == 16)
818  accElementType = builder.getIntegerType(48);
819  else
820  accElementType = builder.getI32Type();
821  auto accType = outputShapedType.clone(accElementType);
822  result.addTypes(accType);
823  } else {
824  result.addTypes(outputType);
825  }
826 }
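// Illustrative note: for quantized matmul the accumulator element type is
// widened from the storage type, e.g. i8 inputs accumulate into i32 and i16
// inputs accumulate into i48, so the inferred result type uses that width
// rather than the caller-provided outputType.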
827 
828 /// Both the tosa.avg_pool2d op and the unary ops use the same
829 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
830 /// it has additional parameters not part of the unary ops.
831 static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
832  OperationState &result,
833  Type outputType, Value input,
834  ArrayAttr kernel, ArrayAttr stride,
835  ArrayAttr pad) {
836  result.addOperands(input);
837  result.addAttribute("kernel", kernel);
838  result.addAttribute("stride", stride);
839  result.addAttribute("pad", pad);
840  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
841  if (quantAttr)
842  result.addAttribute("quantization_info", quantAttr);
843  result.types.push_back(outputType);
844 }
845 
846 /// This builder is called on single-parameter unary operators that have scale
847 /// relationship between their input and output, expressed by the
848 /// UnaryOpQuantizationAttr.
849 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
850  OperationState &result, Type outputType,
851  Value input) {
852  result.addOperands(input);
853  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
854  if (quantAttr)
855  result.addAttribute("quantization_info", quantAttr);
856  result.types.push_back(outputType);
857 }
858 
859 /// This builder is called on TOSA pad operator that needs to create its own
860 /// OptionalAttr quantization_attr parameter to scale the padding values
861 /// correctly. No pad_const is interpreted as zero-padding.
862 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
863  Type outputType, Value input,
864  Value paddings) {
865  result.addOperands({input, paddings});
866  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
867  if (quantAttr)
868  result.addAttribute("quantization_info", quantAttr);
869  result.types.push_back(outputType);
870 }
871 
872 /// This builder is called on TOSA pad operator when an explicit pad_const
873 /// value is passed in. It also optionally constructs quantization_attr.
874 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
875  OperationState &result,
876  Type outputType, Value input,
877  Value paddings,
878  Value padConst) {
879  result.addOperands({input, paddings, padConst});
880  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
881  if (quantAttr)
882  result.addAttribute("quantization_info", quantAttr);
883  result.types.push_back(outputType);
884 }
885 
886 //===----------------------------------------------------------------------===//
887 // TOSA Operator Return Type Inference.
888 //===----------------------------------------------------------------------===//
889 
890 static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
891  for (auto it : arrayAttr) {
892  values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
893  }
894 }
895 
896 static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
897  for (auto it : arrayAttr) {
898  values.push_back(it.cast<FloatAttr>().getValueAsDouble());
899  }
900 }
901 
902 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
903  MLIRContext *context, ::llvm::Optional<Location> location,
904  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
905  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
906  ShapeAdaptor inputShape = operands.getShape(0);
907  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
908  int32_t axisVal = axis.getValue().getSExtValue();
909 
910  if (!inputShape.hasRank()) {
911  inferredReturnShapes.push_back(ShapedTypeComponents());
912  return success();
913  }
914 
915  SmallVector<int64_t> outShape;
916  outShape.reserve(inputShape.getRank() - 1);
917  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
918  if (i == axisVal)
919  continue;
920  outShape.push_back(inputShape.getDimSize(i));
921  }
922 
923  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
924  return success();
925 }
926 
927 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
928  MLIRContext *context, ::llvm::Optional<Location> location,
929  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
930  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
931  // Infer all dimension sizes by reducing based on inputs.
932  int32_t axis =
933  attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
934  llvm::SmallVector<int64_t> outputShape;
935  bool hasRankedInput = false;
936  for (auto operand : operands) {
937  ShapeAdaptor operandShape = operands.getShape(operand);
938  if (!operandShape.hasRank())
939  continue;
940 
941  // Copy the Operand's rank.
942  if (!hasRankedInput)
943  outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);
944 
945  // Copy shapes until the dim is non-dynamic.
946  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
947  if (i == axis || operandShape.isDynamicDim(i))
948  continue;
949  if (outputShape[i] == ShapedType::kDynamicSize)
950  outputShape[i] = operandShape.getDimSize(i);
951  if (outputShape[i] != operandShape.getDimSize(i))
952  return failure();
953  }
954 
955  hasRankedInput = true;
956  }
957 
958  if (!hasRankedInput) {
959  inferredReturnShapes.push_back(ShapedTypeComponents());
960  return success();
961  }
962 
963  // Determine the dimension size along the concatenation axis.
964  int concatDimSize = 0;
965  for (auto operand : operands) {
966  ShapeAdaptor operandShape = operands.getShape(operand);
967 
968  // We need to know the length of the concatenation axis of all inputs to
969  // determine the dimension size of the output shape.
970  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
971  concatDimSize = ShapedType::kDynamicSize;
972  break;
973  }
974 
975  concatDimSize += operandShape.getDimSize(axis);
976  }
977 
978  outputShape[axis] = concatDimSize;
979 
980  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
981  return success();
982 }
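// Worked example (illustrative): concatenating tensor<2x3xf32> and
// tensor<2x5xf32> along axis = 1 infers [2, 8]; if any input's axis dim is
// dynamic, the concatenated dim becomes dynamic.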
983 
984 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
985  MLIRContext *context, ::llvm::Optional<Location> location,
986  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
987  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
988  ShapeAdaptor inputShape = operands.getShape(0);
989  ShapeAdaptor weightShape = operands.getShape(1);
990  ShapeAdaptor biasShape = operands.getShape(2);
991 
992  // All shapes are dynamic.
993  SmallVector<int64_t> outShape;
994  outShape.resize(2, ShapedType::kDynamicSize);
995 
996  if (inputShape.hasRank()) {
997  outShape[0] = inputShape.getDimSize(0);
998  }
999 
1000  if (weightShape.hasRank()) {
1001  outShape[1] = weightShape.getDimSize(0);
1002  }
1003 
1004  if (biasShape.hasRank()) {
1005  outShape[1] = outShape[1] == ShapedType::kDynamicSize
1006  ? biasShape.getDimSize(0)
1007  : outShape[1];
1008  }
1009 
1010  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1011  return success();
1012 }
1013 
1014 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1015  MLIRContext *context, ::llvm::Optional<Location> location,
1016  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1017  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1018  ShapeAdaptor lhsShape = operands.getShape(0);
1019  ShapeAdaptor rhsShape = operands.getShape(1);
1020 
1021  // All shapes are dynamic.
1022  SmallVector<int64_t> outShape;
1023  outShape.resize(3, ShapedType::kDynamicSize);
1024 
1025  if (lhsShape.hasRank()) {
1026  outShape[0] = lhsShape.getDimSize(0);
1027  outShape[1] = lhsShape.getDimSize(1);
1028  }
1029 
1030  if (rhsShape.hasRank()) {
1031  outShape[0] = outShape[0] == ShapedType::kDynamicSize
1032  ? rhsShape.getDimSize(0)
1033  : outShape[0];
1034  outShape[2] = rhsShape.getDimSize(2);
1035  }
1036 
1037  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1038  return success();
1039 }
1040 
1041 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1042  MLIRContext *context, ::llvm::Optional<Location> location,
1043  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1044  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1045  ShapeAdaptor inputShape = operands.getShape(0);
1046  ShapeAdaptor paddingShape = operands.getShape(1);
1047  SmallVector<int64_t> outputShape;
1048 
1049  // If both inputs have unknown shape, we cannot determine the shape of the
1050  // output.
1051  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
1052  inferredReturnShapes.push_back(ShapedTypeComponents());
1053  return success();
1054  }
1055 
1056  // If the input rank is unknown we can infer the output rank using the
1057  // padding shape's first dim.
1058  if (!inputShape.hasRank()) {
1059  if (paddingShape.isDynamicDim(0)) {
1060  inferredReturnShapes.push_back(ShapedTypeComponents());
1061  return success();
1062  }
1063 
1064  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
1065  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1066  return success();
1067  }
1068 
1069  DenseIntElementsAttr paddings;
1070  // If the paddings value is not a constant, all dimensions must be dynamic.
1071  if (!matchPattern(operands[1], m_Constant(&paddings))) {
1072  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
1073  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1074  return success();
1075  }
1076 
1077  SmallVector<int64_t> paddingValues;
1078  for (auto val : paddings) {
1079  paddingValues.push_back(val.getSExtValue());
1080  }
1081 
1082  outputShape.reserve(inputShape.getRank());
1083  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1084  if (inputShape.isDynamicDim(i)) {
1085  outputShape.push_back(ShapedType::kDynamicSize);
1086  continue;
1087  }
1088 
1089  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
1090  paddingValues[i * 2 + 1]);
1091  }
1092 
1093  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1094  return success();
1095 }
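// Worked example (illustrative): an input dim of size 4 with constant
// paddings [1, 2] for that dim gives an inferred output dim of 4 + 1 + 2 = 7;
// any dynamic input dim stays dynamic.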
1096 
1097 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1098  MLIRContext *context, ::llvm::Optional<Location> location,
1099  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1100  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1101  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size();
1102  SmallVector<int64_t> outputShape;
1103  outputShape.reserve(sizes.size());
1104  for (auto val : sizes) {
1105  outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
1106  }
1107 
1108  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1109  return success();
1110 }
1111 
1112 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1113  MLIRContext *context, ::llvm::Optional<Location> location,
1114  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1115  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1116  ShapeAdaptor inputShape = operands.getShape(0);
1117 
1118  if (!inputShape.hasRank()) {
1119  inferredReturnShapes.push_back(ShapedTypeComponents());
1120  return success();
1121  }
1122 
1123  inferredReturnShapes.resize(1);
1124  inputShape.getDims(inferredReturnShapes[0]);
1125  return success();
1126 }
1127 
1128 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1129  MLIRContext *context, ::llvm::Optional<Location> location,
1130  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1131  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1132  TileOpAdaptor adaptor(operands, attributes);
1133  ArrayAttr multiples = adaptor.multiples();
1134  ShapeAdaptor inputShape = operands.getShape(0);
1135  SmallVector<int64_t> outputShape;
1136  if (!inputShape.hasRank()) {
1137  outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
1138  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1139  return success();
1140  }
1141 
1142  // We need the multiple values to determine the output shape.
1143  SmallVector<int64_t> multipleValues;
1144  multipleValues.reserve(multiples.size());
1145  for (auto val : multiples) {
1146  multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
1147  }
1148 
1150  // Any non-dynamic dimension can be multiplied to a known size.
1150  outputShape.reserve(multiples.size());
1151  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1152  int dim = inputShape.getDimSize(i);
1153  if (dim != ShapedType::kDynamicSize)
1154  dim *= multipleValues[i];
1155  outputShape.push_back(dim);
1156  }
1157 
1158  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1159  return success();
1160 }
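// Worked example (illustrative): tiling a tensor<2x?x3xf32> with
// multiples = [3, 2, 1] infers an output shape of [6, ?, 3]; dynamic dims
// remain dynamic.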
1161 
1162 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1163  MLIRContext *context, ::llvm::Optional<Location> location,
1164  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1165  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1166  ReshapeOpAdaptor adaptor(operands, attributes);
1167  ShapeAdaptor inputShape = operands.getShape(0);
1168 
1169  ArrayAttr newShape = adaptor.new_shape();
1170  llvm::SmallVector<int64_t> newShapeValue;
1171  getI64Values(newShape, newShapeValue);
1172 
1173  // We cannot infer from the total number of elements so we must take the
1174  // shape attribute as exact.
1175  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1176  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
1177  return success();
1178  }
1179 
1180  // Determine the number of elements covered by the slice of all static
1181  // dimensions. This allows us to infer the length of the remaining dynamic
1182  // dimension.
1183  int64_t numElements = inputShape.getNumElements();
1184  int64_t staticMul = 1;
1185  for (auto val : newShapeValue) {
1186  if (val != ShapedType::kDynamicSize) {
1187  staticMul *= val;
1188  }
1189  }
1190 
1191  // Determine the length of the dynamic dimension.
1192  for (auto &val : newShapeValue) {
1193  if (val == ShapedType::kDynamicSize)
1194  val = numElements / staticMul;
1195  }
1196 
1197  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
1198  return success();
1199 }
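// Worked example (illustrative): reshaping a static tensor<2x3x4xf32>
// (24 elements) with new_shape = [6, -1] infers the dynamic dim as
// 24 / 6 = 4, yielding tensor<6x4xf32>.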
1200 
1201 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1202  MLIRContext *context, ::llvm::Optional<Location> location,
1203  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1204  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1205  ShapeAdaptor inputShape = operands.getShape(0);
1206  ShapeAdaptor permsShape = operands.getShape(1);
1207 
1208  // If the input rank or permutation length is unknown, the output rank is
1209  // unknown.
1210  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1211  permsShape.isDynamicDim(0)) {
1212  inferredReturnShapes.push_back(ShapedTypeComponents());
1213  return success();
1214  }
1215 
1216  // This would imply the number of permutations does not match the rank of the
1217  // input which is illegal.
1218  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1219  return failure();
1220  }
1221 
1222  // Without the input dims we cannot determine the output dim sizes but we
1223  // can determine the output rank.
1224  SmallVector<int64_t> outputShape;
1225  if (!inputShape.hasRank()) {
1226  outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
1227  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1228  return success();
1229  }
1230 
1231  // Rank-0 means no permutations matter.
1232  if (inputShape.getRank() == 0) {
1233  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1234  return success();
1235  }
1236 
1237  // Check whether the input dimensions are all the same.
1238  bool allTheSame = true;
1239  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1240  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1241  allTheSame = false;
1242  break;
1243  }
1244  }
1245 
1246  // If all of the input dimensions are the same we don't care about the
1247  // permutation.
1248  if (allTheSame) {
1249  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1250  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1251  return success();
1252  }
1253 
1254  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
1255  // If the permutations are a constant we can directly determine the output
1256  // shape.
1257  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
1258  outputShape.reserve(inputShape.getRank());
1259  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1260  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1261  }
1262  }
1263 
1264  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1265  return success();
1266 }
1267 
1268 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1269  MLIRContext *context, ::llvm::Optional<Location> location,
1270  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1271  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1272  llvm::SmallVector<int64_t> outputShape;
1273  outputShape.resize(3, ShapedType::kDynamicSize);
1274 
1275  ShapeAdaptor valuesShape = operands.getShape(0);
1276  if (valuesShape.hasRank()) {
1277  outputShape[0] = valuesShape.getDimSize(0);
1278  outputShape[2] = valuesShape.getDimSize(2);
1279  }
1280 
1281  ShapeAdaptor indicesShape = operands.getShape(1);
1282  if (indicesShape.hasRank()) {
1283  if (outputShape[0] == ShapedType::kDynamicSize)
1284  outputShape[0] = indicesShape.getDimSize(0);
1285  if (outputShape[1] == ShapedType::kDynamicSize)
1286  outputShape[1] = indicesShape.getDimSize(1);
1287  }
1288 
1289  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1290  return success();
1291 }
1292 
1293 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1294  MLIRContext *context, ::llvm::Optional<Location> location,
1295  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1296  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1297  ResizeOpAdaptor adaptor(operands, attributes);
1298  llvm::SmallVector<int64_t, 4> outputShape;
1299  outputShape.resize(4, ShapedType::kDynamicSize);
1300 
1301  int32_t inHeight = ShapedType::kDynamicSize;
1302  int32_t inWidth = ShapedType::kDynamicSize;
1303 
1304  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1305  if (inputShape.hasRank()) {
1306  outputShape[0] = inputShape.getDimSize(0);
1307  outputShape[3] = inputShape.getDimSize(3);
1308 
1309  inHeight = inputShape.getDimSize(1);
1310  inWidth = inputShape.getDimSize(2);
1311  }
1312 
1313  int32_t shift = adaptor.shift();
1314  llvm::SmallVector<int64_t> newShape;
1315  getI64Values(adaptor.output_size(), newShape);
1316  outputShape[1] = newShape[0];
1317  outputShape[2] = newShape[1];
1318 
1319  llvm::SmallVector<int64_t> strideInt;
1320  llvm::SmallVector<int64_t> offsetInt;
1321  llvm::SmallVector<double> strideFp;
1322  llvm::SmallVector<double> offsetFp;
1323  getI64Values(adaptor.offset(), offsetInt);
1324  getF64Values(adaptor.offset_fp(), offsetFp);
1325  getI64Values(adaptor.stride(), strideInt);
1326  getF64Values(adaptor.stride_fp(), strideFp);
1327 
1328  // If the integer stride is zero, the resize indexing needs to be performed
1329  // in floating point. Use the floating-point variant to compute the resize
1330  // shape.
1331  bool fpMode = strideInt[0] == 0;
1332 
1333  // If the output_size attribute leaves a dimension unknown, we can compute it
1334  // from the offset and stride. If we line up perfectly on the last index we
1335  // need to round up the size to include it.
1336  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
1337  float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
1338  float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
1339  outputShape[1] = std::ceil(sizeFp) + round;
1340  }
1341 
1342  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
1343  float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
1344  float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
1345  outputShape[2] = std::ceil(sizeFp) + round;
1346  }
1347 
1348  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
1349  int64_t size = (inHeight - 1);
1350  size = ((size << shift) - offsetInt[0]) / strideInt[0];
1351  outputShape[1] = size + 1;
1352  }
1353 
1354  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
1355  int64_t size = (inWidth - 1);
1356  size = ((size << shift) - offsetInt[1]) / strideInt[1];
1357  outputShape[2] = size + 1;
1358  }
1359 
1360  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1361  return success();
1362 }
1363 
1364 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1365  MLIRContext *context, ::llvm::Optional<Location> location,
1366  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1367  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1368  llvm::SmallVector<int64_t> outputShape;
1369  outputShape.resize(3, ShapedType::kDynamicSize);
1370 
1371  ShapeAdaptor valuesInShape = operands.getShape(0);
1372  if (valuesInShape.hasRank()) {
1373  outputShape[0] = valuesInShape.getDimSize(0);
1374  outputShape[1] = valuesInShape.getDimSize(1);
1375  outputShape[2] = valuesInShape.getDimSize(2);
1376  }
1377 
1378  ShapeAdaptor indicesShape = operands.getShape(1);
1379  if (indicesShape.hasRank()) {
1380  if (outputShape[0] == ShapedType::kDynamicSize)
1381  outputShape[0] = indicesShape.getDimSize(0);
1382  }
1383 
1384  ShapeAdaptor inputShape = operands.getShape(2);
1385  if (inputShape.hasRank()) {
1386  if (outputShape[0] == ShapedType::kDynamicSize)
1387  outputShape[0] = inputShape.getDimSize(0);
1388  if (outputShape[2] == ShapedType::kDynamicSize)
1389  outputShape[2] = inputShape.getDimSize(2);
1390  }
1391 
1392  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1393  return success();
1394 }
1395 
1396 static LogicalResult ReduceInferReturnTypes(
1397  ShapeAdaptor operandShape, IntegerAttr axis,
1398  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1399  if (!operandShape.hasRank()) {
1400  inferredReturnShapes.push_back(ShapedTypeComponents());
1401  return success();
1402  }
1403 
1404  SmallVector<int64_t> outputShape;
1405  operandShape.getDims(outputShape);
1406  int64_t axisVal = axis.getValue().getSExtValue();
1407  outputShape[axisVal] = 1;
1408  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1409  return success();
1410 }
1411 
1412 #define REDUCE_SHAPE_INFER(OP) \
1413  LogicalResult OP::inferReturnTypeComponents( \
1414  MLIRContext *context, ::llvm::Optional<Location> location, \
1415  ValueShapeRange operands, DictionaryAttr attributes, \
1416  RegionRange regions, \
1417  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1418  return ReduceInferReturnTypes(operands.getShape(0), \
1419  attributes.get("axis").cast<IntegerAttr>(), \
1420  inferredReturnShapes); \
1421  }
1422 
1423 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1424 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1425 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1426 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1427 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1428 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1429 #undef REDUCE_SHAPE_INFER
1430 
1431 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1432  SmallVector<int64_t> &outShape) {
1433  int64_t outRank = 0;
1434  for (int i = 0, e = operands.size(); i != e; ++i) {
1435  auto shape = operands.getShape(i);
1436  if (!shape.hasRank()) {
1437  return failure();
1438  }
1439  outRank = std::max<int64_t>(outRank, shape.getRank());
1440  }
1441 
1442  outShape.resize(outRank, 1);
1443 
1444  for (int i = 0, e = operands.size(); i != e; ++i) {
1445  auto shape = operands.getShape(i);
1446  auto rankDiff = outShape.size() - shape.getRank();
1447 
1448  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1449  auto dim1 = outShape[i + rankDiff];
1450  auto dim2 = shape.getDimSize(i);
1451  auto resolvedDim = dim1;
1452 
1453  if (dim1 == 1) {
1454  resolvedDim = dim2;
1455  } else if (dim2 == 1) {
1456  resolvedDim = dim1;
1457  } else if (dim1 != dim2) {
1458  return failure();
1459  }
1460  outShape[i + rankDiff] = resolvedDim;
1461  }
1462  }
1463 
1464  return success();
1465 }
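// Worked example (illustrative): broadcasting shapes [1, 3] and [4, 2, 1]
// right-aligns the ranks and resolves each dim, giving [4, 2, 3]; a mismatch
// such as [2, 3] vs [4, 3] makes resolution fail, and NAryInferReturnTypes
// below then reports an unranked result.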
1466 
1467 static LogicalResult NAryInferReturnTypes(
1468  const ValueShapeRange &operands,
1469  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1470  llvm::SmallVector<int64_t> outShape;
1471  if (resolveBroadcastShape(operands, outShape).failed()) {
1472  inferredReturnShapes.push_back(ShapedTypeComponents());
1473  } else {
1474  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1475  }
1476  return success();
1477 }
1478 
1479 #define NARY_SHAPE_INFER(OP) \
1480  LogicalResult OP::inferReturnTypeComponents( \
1481  MLIRContext *context, ::llvm::Optional<Location> location, \
1482  ValueShapeRange operands, DictionaryAttr attributes, \
1483  RegionRange regions, \
1484  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1485  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1486  }
1487 
1488 NARY_SHAPE_INFER(tosa::AbsOp)
1489 NARY_SHAPE_INFER(tosa::AddOp)
1490 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1491 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1492 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1493 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1494 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1495 NARY_SHAPE_INFER(tosa::CastOp)
1496 NARY_SHAPE_INFER(tosa::CeilOp)
1497 NARY_SHAPE_INFER(tosa::ClampOp)
1498 NARY_SHAPE_INFER(tosa::ClzOp)
1499 NARY_SHAPE_INFER(tosa::DivOp)
1500 NARY_SHAPE_INFER(tosa::EqualOp)
1501 NARY_SHAPE_INFER(tosa::ExpOp)
1502 NARY_SHAPE_INFER(tosa::FloorOp)
1503 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1504 NARY_SHAPE_INFER(tosa::GreaterOp)
1505 NARY_SHAPE_INFER(tosa::IdentityOp)
1506 NARY_SHAPE_INFER(tosa::LogOp)
1507 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1508 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1509 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1510 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1511 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1512 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1513 NARY_SHAPE_INFER(tosa::MaximumOp)
1514 NARY_SHAPE_INFER(tosa::MinimumOp)
1515 NARY_SHAPE_INFER(tosa::MulOp)
1516 NARY_SHAPE_INFER(tosa::NegateOp)
1517 NARY_SHAPE_INFER(tosa::PowOp)
1518 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1519 NARY_SHAPE_INFER(tosa::ReluNOp)
1520 NARY_SHAPE_INFER(tosa::RescaleOp)
1521 NARY_SHAPE_INFER(tosa::ReverseOp)
1522 NARY_SHAPE_INFER(tosa::RsqrtOp)
1523 NARY_SHAPE_INFER(tosa::SelectOp)
1524 NARY_SHAPE_INFER(tosa::SubOp)
1525 NARY_SHAPE_INFER(tosa::TanhOp)
1526 NARY_SHAPE_INFER(tosa::SigmoidOp)
1527 #undef NARY_SHAPE_INFER
1528 
1529 static LogicalResult poolingInferReturnTypes(
1530  const ValueShapeRange &operands, DictionaryAttr attributes,
1531  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1532  ShapeAdaptor inputShape = operands.getShape(0);
1533  llvm::SmallVector<int64_t> outputShape;
1534  outputShape.resize(4, -1);
1535 
1536  // If the input type is unranked we only know the output rank; all dims stay dynamic.
1537  if (!inputShape) {
1538  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1539  return success();
1540  }
1541 
1542  // Batch and number of channels are identical for pooling layer.
1543  outputShape[0] = inputShape.getDimSize(0);
1544  outputShape[3] = inputShape.getDimSize(3);
1545 
1546  int32_t height = inputShape.getDimSize(1);
1547  int32_t width = inputShape.getDimSize(2);
1548 
1548 
1549  llvm::SmallVector<int64_t> kernel;
1550  llvm::SmallVector<int64_t> stride;
1551  llvm::SmallVector<int64_t> pad;
1552 
1553  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
1554  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
1555  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);
1556 
1557  if (height != -1) {
1558  int32_t padded = height + pad[0] + pad[1] - kernel[0];
1559  outputShape[1] = padded / stride[0] + 1;
1560  }
1561 
1562  if (width != -1) {
1563  int32_t padded = width + pad[2] + pad[3] - kernel[1];
1564  outputShape[2] = padded / stride[1] + 1;
1565  }
1566 
1567  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1568  return success();
1569 }
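// Worked example (illustrative): height 8 with pad = [1, 1] (top/bottom),
// kernel = [3, 3] and stride = [2, 2] gives padded = 8 + 1 + 1 - 3 = 7 and
// outputShape[1] = 7 / 2 + 1 = 4.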
1570 
1571 LogicalResult Conv2DOp::inferReturnTypeComponents(
1572  MLIRContext *context, ::llvm::Optional<Location> location,
1573  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1574  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1575  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
1576  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1577 
1578  int32_t inputWidth = ShapedType::kDynamicSize;
1579  int32_t inputHeight = ShapedType::kDynamicSize;
1580  int32_t weightWidth = ShapedType::kDynamicSize;
1581  int32_t weightHeight = ShapedType::kDynamicSize;
1582 
1583  // Input shape describes input width/height and batch.
1584 
1585  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1586  if (inputShape.hasRank()) {
1587  outputShape[0] = inputShape.getDimSize(0);
1588  inputHeight = inputShape.getDimSize(1);
1589  inputWidth = inputShape.getDimSize(2);
1590  }
1591 
1592  // Weight shape describes the filter width/height and the output channels.
1593  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1594  if (weightShape.hasRank()) {
1595  outputShape[3] = weightShape.getDimSize(0);
1596  weightHeight = weightShape.getDimSize(1);
1597  weightWidth = weightShape.getDimSize(2);
1598  }
1599 
1600  // Bias shape can describe the output channels.
1601  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1602  if (biasShape.hasRank()) {
1603  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1604  ? biasShape.getDimSize(0)
1605  : outputShape[3];
1606  }
1607 
1608  llvm::SmallVector<int64_t> dilation;
1609  llvm::SmallVector<int64_t> padding;
1610  llvm::SmallVector<int64_t> stride;
1611 
1612  getI64Values(adaptor.dilation(), dilation);
1613  getI64Values(adaptor.pad(), padding);
1614  getI64Values(adaptor.stride(), stride);
1615 
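  // For each statically known spatial dimension the output size is
  //   out = floor((in + pad_before + pad_after - dilated_filter) / stride) + 1,
  // where dilated_filter = (kernel - 1) * dilation + 1. Illustrative example
  // (hypothetical values): in = 14, pad = [1, 1], kernel = 3, dilation = 1,
  // stride = 2 gives dilated_filter = 3 and out = (14 - 1) / 2 + 1 = 7.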
1616  if (!ShapedType::isDynamic(inputHeight) &&
1617  !ShapedType::isDynamic(weightHeight)) {
1618  int32_t inputSize = inputHeight + padding[0] + padding[1];
1619  int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1620  int32_t unstridedResult = inputSize - filterSize + 1;
1621  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1622  }
1623 
1624  if (!ShapedType::isDynamic(inputWidth) &&
1625  !ShapedType::isDynamic(weightWidth)) {
1626  int32_t inputSize = inputWidth + padding[2] + padding[3];
1627  int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1628  int32_t unstridedResult = inputSize - filterSize + 1;
1629  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1630  }
1631 
1632  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1633  return success();
1634 }
1635 
1636 LogicalResult Conv3DOp::inferReturnTypeComponents(
1637  MLIRContext *context, ::llvm::Optional<Location> location,
1638  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1639  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1640  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
1641  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);
1642 
1643  int32_t inputWidth = ShapedType::kDynamicSize;
1644  int32_t inputHeight = ShapedType::kDynamicSize;
1645  int32_t inputDepth = ShapedType::kDynamicSize;
1646 
1647  int32_t weightWidth = ShapedType::kDynamicSize;
1648  int32_t weightHeight = ShapedType::kDynamicSize;
1649  int32_t weightDepth = ShapedType::kDynamicSize;
1650 
1651  // Input shape describes input height/width/depth and batch.
1652  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1653  if (inputShape.hasRank()) {
1654  outputShape[0] = inputShape.getDimSize(0);
1655  inputHeight = inputShape.getDimSize(1);
1656  inputWidth = inputShape.getDimSize(2);
1657  inputDepth = inputShape.getDimSize(3);
1658  }
1659 
1660  // Weight shape describes the filter height/width/depth and the output channels.
1661  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1662  if (weightShape.hasRank()) {
1663  outputShape[4] = weightShape.getDimSize(0);
1664  weightHeight = weightShape.getDimSize(1);
1665  weightWidth = weightShape.getDimSize(2);
1666  weightDepth = weightShape.getDimSize(3);
1667  }
1668 
1669  // Bias shape can describe the output channels.
1670  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1671  if (biasShape.hasRank()) {
1672  outputShape[4] = ShapedType::isDynamic(outputShape[4])
1673  ? biasShape.getDimSize(0) : outputShape[4];
1674  }
1675 
1676  llvm::SmallVector<int64_t> dilation;
1677  llvm::SmallVector<int64_t> padding;
1678  llvm::SmallVector<int64_t> stride;
1679 
1680  getI64Values(adaptor.dilation(), dilation);
1681  getI64Values(adaptor.pad(), padding);
1682  getI64Values(adaptor.stride(), stride);
1683 
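  // The same dilated-filter formula as Conv2D is applied independently to
  // each spatial dimension below; pad, stride and dilation are indexed per
  // dimension in the order the attributes list them.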
1684  if (!ShapedType::isDynamic(inputHeight) &&
1685  !ShapedType::isDynamic(weightHeight)) {
1686  int32_t inputSize = inputHeight + padding[0] + padding[1];
1687  int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1688  int32_t unstridedResult = inputSize - filterSize + 1;
1689  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1690  }
1691 
1692  if (!ShapedType::isDynamic(inputWidth) &&
1693  !ShapedType::isDynamic(weightWidth)) {
1694  int32_t inputSize = inputWidth + padding[2] + padding[3];
1695  int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1696  int32_t unstridedResult = inputSize - filterSize + 1;
1697  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1698  }
1699 
1700  if (!ShapedType::isDynamic(inputDepth) &&
1701  !ShapedType::isDynamic(weightDepth)) {
1702  int32_t inputSize = inputDepth + padding[4] + padding[5];
1703  int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
1704  int32_t unstridedResult = inputSize - filterSize + 1;
1705  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1706  }
1707 
1708  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1709  return success();
1710 }
1711 
1712 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1713  MLIRContext *context, ::llvm::Optional<Location> location,
1714  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1715  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1716  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
1717 }
1718 
1719 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1720  MLIRContext *context, ::llvm::Optional<Location> location,
1721  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1722  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1723  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
1724 }
1725 
1726 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1727  MLIRContext *context, ::llvm::Optional<Location> location,
1728  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1729  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1730  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
1731  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1732 
1733  int32_t inputWidth = ShapedType::kDynamicSize;
1734  int32_t inputHeight = ShapedType::kDynamicSize;
1735  int32_t inputChannels = ShapedType::kDynamicSize;
1736 
1737  int32_t weightWidth = ShapedType::kDynamicSize;
1738  int32_t weightHeight = ShapedType::kDynamicSize;
1739  int32_t depthChannels = ShapedType::kDynamicSize;
1740 
1741  // Input shape describes input width/height and batch.
1742  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1743  if (inputShape.hasRank()) {
1744  outputShape[0] = inputShape.getDimSize(0);
1745  inputHeight = inputShape.getDimSize(1);
1746  inputWidth = inputShape.getDimSize(2);
1747  inputChannels = inputShape.getDimSize(3);
1748  }
1749 
1750  // Weight shape describes the filter height/width, the input channels and
1751  // the channel multiplier.
1751  ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
1752  if (weightShape.hasRank()) {
1753  weightHeight = weightShape.getDimSize(0);
1754  weightWidth = weightShape.getDimSize(1);
1755  inputChannels = ShapedType::isDynamic(inputChannels)
1756  ? weightShape.getDimSize(2)
1757  : inputChannels;
1758  depthChannels = weightShape.getDimSize(3);
1759  }
1760 
1761  // If both inputChannels and depthChannels are available we can determine
1762  // the output channels.
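  // Illustrative example (hypothetical values): 8 input channels with a
  // channel multiplier of 2 gives 8 * 2 = 16 output channels.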
1763  if (!ShapedType::isDynamic(inputChannels) &&
1764  !ShapedType::isDynamic(depthChannels)) {
1765  outputShape[3] = inputChannels * depthChannels;
1766  }
1767 
1768  // Bias shape can describe the output channels.
1769  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1770  if (biasShape.hasRank()) {
1771  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1772  ? biasShape.getDimSize(0)
1773  : outputShape[3];
1774  }
1775 
1776  llvm::SmallVector<int64_t> dilation;
1777  llvm::SmallVector<int64_t> padding;
1778  llvm::SmallVector<int64_t> stride;
1779 
1780  getI64Values(adaptor.dilation(), dilation);
1781  getI64Values(adaptor.pad(), padding);
1782  getI64Values(adaptor.stride(), stride);
1783 
1784  if (!ShapedType::isDynamic(inputHeight) &&
1785  !ShapedType::isDynamic(weightHeight)) {
1786  int32_t inputSize = inputHeight + padding[0] + padding[1];
1787  int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1788  int32_t unstridedResult = inputSize - filterSize + 1;
1789  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1790  }
1791 
1792  if (!ShapedType::isDynamic(inputWidth) &&
1793  !ShapedType::isDynamic(weightWidth)) {
1794  int32_t inputSize = inputWidth + padding[2] + padding[3];
1795  int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1796  int32_t unstridedResult = inputSize - filterSize + 1;
1797  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1798  }
1799 
1800  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1801  return success();
1802 }
1803 
1804 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1805  MLIRContext *context, ::llvm::Optional<Location> location,
1806  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1807  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1808  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
1809  llvm::SmallVector<int64_t> outputShape;
1810  getI64Values(adaptor.out_shape(), outputShape);
1811 
1812  int32_t inputWidth = ShapedType::kDynamicSize;
1813  int32_t inputHeight = ShapedType::kDynamicSize;
1814  int32_t weightWidth = ShapedType::kDynamicSize;
1815  int32_t weightHeight = ShapedType::kDynamicSize;
1816 
1817  // Input shape describes input width/height and batch.
1818  ShapeAdaptor inputShape = operands.getShape(adaptor.input());
1819  if (inputShape.hasRank()) {
1820  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1821  ? inputShape.getDimSize(0)
1822  : outputShape[0];
1823  inputHeight = inputShape.getDimSize(1);
1824  inputWidth = inputShape.getDimSize(2);
1825  }
1826 
1827  // Weight shape describes the filter width/height and the output channels.
1828  ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
1829  if (weightShape.hasRank()) {
1830  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1831  ? weightShape.getDimSize(0)
1832  : outputShape[3];
1833  weightHeight = weightShape.getDimSize(1);
1834  weightWidth = weightShape.getDimSize(2);
1835  }
1836 
1837  // Bias shape can describe the output channels.
1838  ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
1839  if (biasShape.hasRank()) {
1840  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1841  ? biasShape.getDimSize(0)
1842  : outputShape[3];
1843  }
1844 
1845  llvm::SmallVector<int64_t> dilation;
1846  llvm::SmallVector<int64_t> padding;
1847  llvm::SmallVector<int64_t> stride;
1848 
1849  getI64Values(adaptor.dilation(), dilation);
1850  getI64Values(adaptor.out_pad(), padding);
1851  getI64Values(adaptor.stride(), stride);
1852 
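  // When not fixed by out_shape, each spatial output size follows the
  // transposed-convolution formula
  //   out = (in - 1) * stride - out_pad + (kernel - 1) * dilation + 1.
  // Illustrative example (hypothetical values): in = 7, stride = 2,
  // out_pad = 0, kernel = 3, dilation = 1 gives out = 12 + 3 = 15.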
1853  if (!ShapedType::isDynamic(inputHeight) &&
1854  !ShapedType::isDynamic(weightHeight)) {
1855  int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
1856  int32_t calculateSize =
1857  (inputHeight - 1) * stride[0] - padding[0] + dilated;
1858  outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
1859  }
1860 
1861  if (!ShapedType::isDynamic(inputWidth) &&
1862  !ShapedType::isDynamic(weightWidth)) {
1863  int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
1864  int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
1865  outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
1866  }
1867 
1868  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1869  return success();
1870 }
1871 
1872 LogicalResult IfOp::inferReturnTypeComponents(
1873  MLIRContext *context, ::llvm::Optional<Location> location,
1874  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1875  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1876  llvm::SmallVector<tosa::YieldOp> yieldOps;
1877  for (Region *region : regions) {
1878  for (auto &block : *region)
1879  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1880  yieldOps.push_back(returnOp);
1881  }
1882 
1883  if (yieldOps.empty())
1884  return failure();
1885 
1886  // Get the initial type information for the yield op.
1887  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1888  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1889  for (auto operand : yieldOps.front().getOperands()) {
1890  resultKnowledge.push_back(
1891  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1892  }
1893 
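  // Each result keeps a dimension static only when every yield agrees on it;
  // meeting shapes such as 4x?xf32 and 4x8xf32 is expected to give 4x?xf32
  // (illustrative shapes; the exact semantics are those of
  // ValueKnowledge::meet in ShapeUtils.h).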
1894  for (auto yieldOp : yieldOps) {
1895  if (resultKnowledge.size() != yieldOp.getNumOperands())
1896  return failure();
1897 
1898  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1899  int32_t index = it.index();
1900  auto meet = ValueKnowledge::meet(
1901  resultKnowledge[index],
1902  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1903  if (!meet)
1904  continue;
1905  resultKnowledge[index] = meet;
1906  }
1907  }
1908 
1909  for (const ValueKnowledge &result : resultKnowledge) {
1910  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1911  }
1912 
1913  return success();
1914 }
1915 
1916 LogicalResult WhileOp::inferReturnTypeComponents(
1917  MLIRContext *context, ::llvm::Optional<Location> location,
1918  ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
1919  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
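  // regions[0] is the condition region and regions[1] is the loop body; the
  // result shapes below are derived from the body's tosa.yield operands.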
1920  llvm::SmallVector<tosa::YieldOp> yieldOps;
1921  for (auto &block : *regions[1])
1922  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1923  yieldOps.push_back(returnOp);
1924 
1925  // TOSA's while must have a tosa.yield as its terminator. If not found this
1926  // tosa.while is invalid.
1927  if (yieldOps.empty())
1928  return failure();
1929 
1930  // Get the initial type information from the operand types.
1931  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1932  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1933  for (auto operand : yieldOps.front().getOperands()) {
1934  resultKnowledge.push_back(
1935  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1936  }
1937 
1938  for (auto yieldOp : yieldOps) {
1939  if (resultKnowledge.size() != yieldOp.getNumOperands())
1940  return failure();
1941 
1942  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1943  int32_t index = it.index();
1944  if (auto meet = ValueKnowledge::meet(
1945  resultKnowledge[index],
1946  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1947  resultKnowledge[index] = meet;
1948  }
1949  }
1950  }
1951 
1952  for (const ValueKnowledge &result : resultKnowledge) {
1953  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1954  }
1955 
1956  return success();
1957 }
1958 
1959 //===----------------------------------------------------------------------===//
1960 // TOSA Operator Definitions.
1961 //===----------------------------------------------------------------------===//
1962 
1963 #define GET_OP_CLASSES
1964 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"