1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 
17 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
18 #include "mlir/Dialect/Quant/IR/Quant.h"
19 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
20 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
21 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/Interfaces/InferTypeOpInterface.h"
28 #include "mlir/Transforms/InliningUtils.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49 
50 namespace {
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52 
53 //===----------------------------------------------------------------------===//
54 // Dialect Function Inliner Interface.
55 //===----------------------------------------------------------------------===//
56 struct TosaInlinerInterface : public DialectInlinerInterface {
57  using DialectInlinerInterface::DialectInlinerInterface;
58 
59  //===--------------------------------------------------------------------===//
60  // Analysis Hooks.
61  //===--------------------------------------------------------------------===//
62 
63  /// All operations can be inlined by default.
64  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return true;
67  }
68 
69  /// All regions with If and While parent operators can be inlined.
70  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71  IRMapping &map) const final {
72  return (isa<tosa::IfOp>(dest->getParentOp()) ||
73  isa<tosa::WhileOp>(dest->getParentOp()));
74  }
75 };
76 
77 /// This class implements the bytecode interface for the Tosa dialect.
78 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79  TosaDialectBytecodeInterface(Dialect *dialect)
80  : BytecodeDialectInterface(dialect) {}
81 
82  //===--------------------------------------------------------------------===//
83  // Attributes
84 
85  Attribute readAttribute(DialectBytecodeReader &reader) const override {
86  return ::readAttribute(getContext(), reader);
87  }
88 
89  LogicalResult writeAttribute(Attribute attr,
90  DialectBytecodeWriter &writer) const override {
91  return ::writeAttribute(attr, writer);
92  }
93 
94  //===--------------------------------------------------------------------===//
95  // Types
96 
97  Type readType(DialectBytecodeReader &reader) const override {
98  return ::readType(getContext(), reader);
99  }
100 
101  LogicalResult writeType(Type type,
102  DialectBytecodeWriter &writer) const override {
103  return ::writeType(type, writer);
104  }
105 
106  void writeVersion(DialectBytecodeWriter &writer) const final {
107  // TODO: Populate.
108  }
109 
110  std::unique_ptr<DialectVersion>
111  readVersion(DialectBytecodeReader &reader) const final {
112  // TODO: Populate
113  reader.emitError("Dialect does not support versioning");
114  return nullptr;
115  }
116 
117  LogicalResult upgradeFromVersion(Operation *topLevelOp,
118  const DialectVersion &version) const final {
119  return success();
120  }
121 };
122 
123 } // namespace
124 
125 //===----------------------------------------------------------------------===//
126 // TOSA control flow support.
127 //===----------------------------------------------------------------------===//
128 
129 /// Returns the while loop body.
130 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131  return {&getBodyGraph()};
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Tosa dialect initialization.
136 //===----------------------------------------------------------------------===//
137 
138 void TosaDialect::initialize() {
139  addTypes<
140 #define GET_TYPEDEF_LIST
141 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
142  >();
143  addOperations<
144 #define GET_OP_LIST
145 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
146  >();
147  addAttributes<
148 #define GET_ATTRDEF_LIST
149 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
150  >();
151  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
152  declarePromisedInterfaces<
153  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
154  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
155  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
156  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
157  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
158  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
159  GreaterEqualOp, MatMulOp>();
160 }
161 
162 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
163                                             Type type, Location loc) {
164  // TOSA dialect constants only support ElementsAttr, unlike the standard
165  // dialect constant, which supports all attributes.
166  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
167  return builder.create<tosa::ConstShapeOp>(
168  loc, type, llvm::cast<DenseIntElementsAttr>(value));
169  }
170  if (llvm::isa<ElementsAttr>(value))
171  return builder.create<tosa::ConstOp>(loc, type,
172  llvm::cast<ElementsAttr>(value));
173  return nullptr;
174 }
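// Illustrative note (sketch, not taken from the source): folding that yields an
// ElementsAttr materializes roughly as
//   %c = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
// while a DenseIntElementsAttr paired with a !tosa.shape type materializes as a
// tosa.const_shape op.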
175 
176 //===----------------------------------------------------------------------===//
177 // Parsers and printers
178 //===----------------------------------------------------------------------===//
179 
180 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
181  Attribute &attr) {
182  if (succeeded(parser.parseOptionalEqual())) {
183  if (failed(parser.parseAttribute(attr))) {
184  return parser.emitError(parser.getCurrentLocation())
185  << "expected attribute";
186  }
187  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188  typeAttr = TypeAttr::get(typedAttr.getType());
189  }
190  return success();
191  }
192 
193  Type type;
194  if (failed(parser.parseColonType(type))) {
195  return parser.emitError(parser.getCurrentLocation()) << "expected type";
196  }
197  typeAttr = TypeAttr::get(type);
198 
199  return success();
200 }
201 
202 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
203                                  Attribute attr) {
204  bool needsSpace = false;
205  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206  if (!typedAttr || typedAttr.getType() != type.getValue()) {
207  p << ": ";
208  p.printAttribute(type);
209  needsSpace = true; // subsequent attr value needs a space separator
210  }
211  if (attr) {
212  if (needsSpace)
213  p << ' ';
214  p << "= ";
215  p.printAttribute(attr);
216  }
217 }
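// Usage note (descriptive, added for clarity): the parse/print pair above
// implements a custom directive that accepts either ": <type>" (type only,
// attribute left unset) or "= <attr>" (attribute given; when it is a typed
// attribute, its type is recorded in the TypeAttr).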
218 
219 //===----------------------------------------------------------------------===//
220 // Tosa utilities.
221 //===----------------------------------------------------------------------===//
222 
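// Returns lhs / rhs when rhs divides lhs exactly, std::nullopt otherwise.
// Callers are expected to pass a non-zero rhs (e.g. strides are verified to be
// >= 1 before the pooling verifier uses this).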
223 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
224  if (lhs % rhs != 0)
225  return std::nullopt;
226  return lhs / rhs;
227 }
228 
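// Returns the storage element type for quantized element types, otherwise the
// element type itself (or the type unchanged if it is not shaped).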
229 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
230  auto srcType = getElementTypeOrSelf(type);
231  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
232  srcType = quantType.getStorageType();
233  return srcType;
234 }
235 
236 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
237  return getStorageElementTypeOrSelf(value.getType());
238 }
239 
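// Verifies that a value and its zero-point operand both have integer element
// types of the same bit width (using storage types for quantized elements).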
240 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
241  Value valZp, StringRef name) {
242  Type eType = getStorageElementTypeOrSelf(val.getType());
243  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
244 
245  bool bothInts =
246  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
247  bool sameBitWidth =
248  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
249 
250  if (!bothInts || !sameBitWidth) {
251  return op->emitOpError()
252  << "expected " << name << " and " << name
253  << "_zp to both be integer of the same bitwidth, but got " << eType
254  << " vs. " << eZpType;
255  }
256  return success();
257 }
258 
259 // Create a pad-const tensor with value `val` of the required data type.
260 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
261                                        Value src, int32_t val) {
262  const auto srcType = getElementTypeOrSelf(src);
263  const auto srcElemType = getStorageElementTypeOrSelf(src);
264  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
265  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
266  const auto padConstAttr{
267  llvm::isa<FloatType>(srcElemType)
268  ? DenseElementsAttr::get(padConstEType,
269  builder.getFloatAttr(srcElemType, val))
270  : DenseElementsAttr::get(padConstEType,
271  builder.getIntegerAttr(srcElemType, val))};
272  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // TOSA Operator Verifiers.
277 //===----------------------------------------------------------------------===//
278 
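// Shared verifier for the TOSA convolution ops: checks that input and weight
// are ranked tensors, that input/weight/bias/result element types are
// consistently float or integer, and that each zero-point operand matches the
// element type of its corresponding tensor.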
279 template <typename T>
280 static LogicalResult verifyConvOp(T op) {
281  // All TOSA conv ops have input and weight arguments, which must be ranked
282  // tensors.
283  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
284  if (!inputType) {
285  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
286  return failure();
287  }
288 
289  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
290  if (!weightType) {
291  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
292  return failure();
293  }
294 
295  auto inputEType = inputType.getElementType();
296  auto weightEType = weightType.getElementType();
297  auto biasEType =
298  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
299  auto resultEType =
300  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
301  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
302  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
303 
304  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
305  inputEType = quantType.getStorageType();
306 
307  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
308  weightEType = quantType.getStorageType();
309 
310  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
311  biasEType = quantType.getStorageType();
312 
313  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
314  resultEType = quantType.getStorageType();
315 
316  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
317  // for now, only enforce bias element type == result element type for
318  // float types.
319  op.emitOpError(
320  "expect both bias and result to have same element type, got ")
321  << biasEType << " and " << resultEType;
322  return failure();
323  }
324 
325  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
326  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
327  if (inputEType != weightEType) {
328  op.emitOpError(
329  "expect both input and weight to have same element type, got ")
330  << inputEType << " and " << weightEType;
331  return failure();
332  }
333  }
334 
335  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
336  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
337 
338  // Either both must be float or both non-float.
339  if (inputIsFloat != weightIsFloat) {
340  op.emitOpError(
341  "expect both input and weight to be float or not together, got ")
342  << inputEType << " and " << weightEType;
343  return failure();
344  }
345 
346  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
347  if (inputEType != inputZpEType) {
348  return op.emitOpError("expect both input and its zero point are the same "
349  "element type, got ")
350  << inputEType << " and " << inputZpEType;
351  }
352 
353  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
354  if (weightEType != weightZpEType) {
355  return op.emitOpError("expect both weight and its zero point are the same "
356  "element type, got ")
357  << weightEType << " and " << weightZpEType;
358  }
359 
360  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
361  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
362  return failure();
363 
364  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
365  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
366  return failure();
367 
368  return success();
369 }
370 
371 LogicalResult tosa::ConstOp::verify() {
372 
373  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
374  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
375 
376  if (!attrType || !outputType) {
377  emitOpError("expected tensors for attr/result type");
378  return failure();
379  }
380 
381  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
382  outputType.getElementType())) {
383  if (result.getStorageType() == attrType.getElementType())
384  return success();
385  }
386 
387  if (attrType.getElementType() != outputType.getElementType()) {
388  emitOpError("expected same attr/result element types");
389  return failure();
390  }
391 
392  return success();
393 }
394 
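// Verifies the accumulator type and the input/result element type pairing for
// convolution-like ops, e.g. i8 inputs accumulate in i32, i16 in i48, f8 in
// f16, and f16 in either f16 or f32.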
395 template <typename T>
396 static LogicalResult verifyConvOpModes(T op) {
397  auto inputEType =
398  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
399 
400  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
401  inputEType = quantType.getStorageType();
402 
403  auto accType = op.getAccType();
404  if (inputEType.isInteger(8) && !accType.isInteger(32))
405  return op.emitOpError("accumulator type for i8 tensor is not i32");
406 
407  if (inputEType.isInteger(16) && !accType.isInteger(48))
408  return op.emitOpError("accumulator type for i16 tensor is not i48");
409 
410  if (isa<Float8E5M2Type, Float8E4M3FNType>(inputEType) && !accType.isF16())
411  return op.emitOpError("accumulator type for f8 tensor is not f16");
412 
413  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
414  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
415 
416  if (inputEType.isBF16() && !accType.isF32())
417  return op.emitOpError("accumulator type for bf16 tensor is not f32");
418 
419  if (inputEType.isF32() && !accType.isF32())
420  return op.emitOpError("accumulator type for f32 tensor is not f32");
421 
422  auto resultEType =
423  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
424 
425  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
426  resultEType = quantType.getStorageType();
427 
428  // Check allowed input/result element type combinations.
429  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
430  (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
431  (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
432  (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
433  (inputEType.isF16() && resultEType.isF16()) ||
434  (inputEType.isBF16() && resultEType.isBF16()) ||
435  (inputEType.isF32() && resultEType.isF32()))
436  return success();
437 
438  return op.emitOpError("input/output element types are incompatible.");
439 }
440 
441 // verify that inType and outType have same element types
442 template <typename T>
443 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
444  auto inputType = llvm::dyn_cast<TensorType>(inType);
445  auto outputType = llvm::dyn_cast<TensorType>(outType);
446  if (!inputType) {
447  op.emitOpError("expect shaped tensor for input, got ") << inType;
448  return failure();
449  }
450  if (!outputType) {
451  op.emitOpError("expect shaped tensor for output, got ") << outType;
452  return failure();
453  }
454  auto inputElementType = inputType.getElementType();
455  auto outputElementType = outputType.getElementType();
456  auto inputQuantType =
457  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
458  auto outputQuantType =
459  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
460  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
461  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
462  inputElementType != outputElementType) {
463  // Only check when both element types are int/index/float/UniformQuantized;
464  // it is unclear how to compare generic quant::QuantizedType values. This
465  // case arises in test_conv2d_q_grouped_convolution in
466  // tfl-to-tosa-pipeline.mlir.
467  op.emitOpError("expect input and output to have same element type, got ")
468  << inputElementType << " and " << outputElementType;
469  return failure();
470  }
471  return success();
472 }
473 
474 LogicalResult tosa::ArgMaxOp::verify() {
475  const ShapedType resultType = llvm::cast<ShapedType>(getType());
476 
477  // Ensure the result element type is an integer type
478  if (const auto resultETy = resultType.getElementType();
479  !resultETy.isIntOrIndex())
480  return emitOpError("result tensor is not of integer type");
481 
482  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
483  if (!inputType.hasRank())
484  return success();
485 
486  // Ensure axis is within the tensor rank
487  const int64_t axis = getAxisAttr().getInt();
488  if (((axis < 0) || axis >= inputType.getRank()))
489  return emitOpError("specified axis is outside the rank of the tensor");
490 
491  if (!resultType.hasRank())
492  return success();
493 
494  const ArrayRef<int64_t> inputShape = inputType.getShape();
495  const ArrayRef<int64_t> outputShape = resultType.getShape();
496  llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
497  inputShape.end());
498  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
499  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
500  return emitOpError("expected output shape '")
501  << expectedOutputShape << "', got '" << outputShape << "'";
502 
503  return success();
504 }
505 
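// Shared verifier for the 2D pooling ops: checks kernel/stride/pad bounds plus
// the expected output size per spatial dimension,
//   out = (in + pad_before + pad_after - kernel) / stride + 1,
// where the division must be exact.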
506 template <typename T>
507 static LogicalResult verifyPoolingOp(T op) {
508  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
509  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
510  return op.emitOpError("expect all kernel values to be >= 1, got ")
511  << kernel;
512 
513  const llvm::ArrayRef<int64_t> strides = op.getStride();
514  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
515  return op.emitOpError("expect all stride values to be >= 1, got ")
516  << strides;
517 
518  const llvm::ArrayRef<int64_t> padding = op.getPad();
519  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
520  return op.emitOpError("expect all padding values to be >= 0, got ")
521  << padding;
522 
523  // Padding must be less than kernel size to avoid a divide-by-zero
524  const int64_t kernelX = kernel[1];
525  const int64_t padLeft = padding[2];
526  const int64_t padRight = padding[3];
527  if (padRight >= kernelX || padLeft >= kernelX)
528  return op.emitOpError("expected left/right padding to be less than the "
529  "width of the kernel, got pad_left=")
530  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
531 
532  const int64_t kernelY = kernel[0];
533  const int64_t padTop = padding[0];
534  const int64_t padBottom = padding[1];
535  if (padTop >= kernelY || padBottom >= kernelY)
536  return op.emitOpError("expected top/bottom padding to be less than the "
537  "height of the kernel, got pad_top=")
538  << padTop << ", pad_bottom=" << padBottom
539  << ", kernel_y=" << kernelY;
540 
541  const auto inputType =
542  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
543  const auto outputType =
544  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
545  if (!inputType || !outputType)
546  return success();
547 
548  const auto verifyOutputSize =
549  [&op](const int64_t inputSize, const int64_t outputSize,
550  const int64_t kernelSize, const int64_t strideSize,
551  const int64_t padBefore, const int64_t padAfter,
552  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
553  const llvm::StringRef padBeforeName,
554  const llvm::StringRef padAfterName) -> LogicalResult {
555  if (ShapedType::isDynamic(inputSize))
556  return success();
557 
558  const std::optional<int64_t> calculatedOutSizeMinusOne =
559  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
560  if (!calculatedOutSizeMinusOne.has_value())
561  return op.emitOpError("expected input_")
562  << dimName << " + pad_" << padBeforeName << " + pad_"
563  << padAfterName << " - kernel_" << dimAxis
564  << " to be wholly divisible by stride_" << dimAxis << ", got ("
565  << inputSize << " + " << padBefore << " + " << padAfter << " - "
566  << kernelSize << ") / " << strideSize;
567 
568  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
569  if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
570  return op.emitOpError("calculated output ")
571  << dimName << " did not match expected: "
572  << "calculated=" << calculatedOutSize
573  << ", expected=" << outputSize;
574 
575  return success();
576  };
577 
578  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
579  kernel[0], strides[0], padding[0], padding[1],
580  "height", "y", "top", "bottom")))
581  return failure();
582 
583  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
584  kernel[1], strides[1], padding[2], padding[3],
585  "width", "x", "left", "right")))
586  return failure();
587 
588  return success();
589 }
590 
591 LogicalResult tosa::AvgPool2dOp::verify() {
592  if (failed(verifyPoolingOp(*this)))
593  return failure();
594 
595  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
596  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
597  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
598  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
599 
600  auto accType = getAccType();
601  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
602  return emitOpError("accumulator type for integer tensor is not i32");
603 
604  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
605  return emitOpError("accumulator type for f16 tensor is not f16/f32");
606 
607  if (inputETy.isBF16() && !accType.isF32())
608  return emitOpError("accumulator type for bf16 tensor is not f32");
609 
610  if (inputETy.isF32() && !accType.isF32())
611  return emitOpError("accumulator type for f32 tensor is not f32");
612 
613  if (inputETy != inputZpETy)
614  return emitOpError("expect both input and its zero point are the same "
615  "element type, got ")
616  << inputETy << " and " << inputZpETy;
617 
618  if (resultETy != outputZpETy)
619  return emitOpError("expect both output and its zero point are the same "
620  "element type, got ")
621  << resultETy << " and " << outputZpETy;
622 
623  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
624  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
625  return failure();
626 
627  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
628  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
629  return failure();
630 
631  return success();
632 }
633 
634 LogicalResult tosa::ClampOp::verify() {
635  mlir::Type inputETy =
636  llvm::cast<ShapedType>(getInput().getType()).getElementType();
637  if (auto quantType =
638  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
639  inputETy = quantType.getStorageType();
640  }
641  mlir::Type outputETy =
642  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
643  if (auto quantType =
644  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
645  outputETy = quantType.getStorageType();
646  }
647  if (inputETy != outputETy)
648  return emitOpError("input/output element types are incompatible.");
649 
650  auto maxValAttr = getMaxValAttr();
651  auto minValAttr = getMinValAttr();
652 
653  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
654 
655  if (inputETy.isInteger(dataTypeBitWidth)) {
656  // if input datatype is integer, check that the min_val/max_val attributes
657  // are integer attributes, and that their type is the same as the input's
658  // datatype
659  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
660  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
661  if (!intMaxValAttr || !intMinValAttr ||
662  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
663  (intMaxValAttr.getType() != inputETy))
664  return emitOpError("min/max attributes types are incompatible with "
665  "input/output element types.");
666  } else {
667  // otherwise, input datatype is float, check that the min_val/max_val
668  // attributes share the same type and that their type is the same as the
669  // input's datatype
670  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
671  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
672  if (!floatMaxValAttr || !floatMinValAttr ||
673  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
674  (floatMaxValAttr.getType() != inputETy))
675  return emitOpError("min/max attributes types are incompatible with "
676  "input/output element types.");
677  }
678 
679  return success();
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // TOSA Operator Quantization Builders.
684 //===----------------------------------------------------------------------===//
685 
686 /// This builder is called on all convolution operators except TransposeConv,
687 /// which has specialized output shape semantics. The builder also defines the
688 /// bitwidth of the output given the bit width of the input & weight content.
689 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
690  Type outputType, Value input, Value weight,
691  Value bias, DenseI64ArrayAttr pad,
692  DenseI64ArrayAttr stride,
693  DenseI64ArrayAttr dilation,
694  TypeAttr accType) {
695  auto zps = createZPsAsConst(builder, input, weight);
696  result.addOperands({input, weight, bias, zps.first, zps.second});
697  result.addAttribute("pad", pad);
698  result.addAttribute("stride", stride);
699  result.addAttribute("dilation", dilation);
700  result.addAttribute("acc_type", accType);
701  Type finalOutputType = outputType;
702  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
703  if (quantAttr) {
704  finalOutputType =
705  buildConvOpResultTypeInfo(builder, outputType, input, weight);
706  }
707  result.addTypes(finalOutputType);
708 }
709 
710 /// Handles tosa.transpose_conv2d which has outpad and output shape
711 /// attributes.
712 static void
713 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
714                               Type outputType, Value input, Value weight,
715  Value bias, DenseI64ArrayAttr outpad,
716  DenseI64ArrayAttr stride, TypeAttr accType) {
717  auto zps = createZPsAsConst(builder, input, weight);
718  result.addOperands({input, weight, bias, zps.first, zps.second});
719  result.addAttribute("out_pad", outpad);
720  result.addAttribute("stride", stride);
721  result.addAttribute("acc_type", accType);
722  Type finalOutputType = outputType;
723  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
724  if (quantAttr) {
725  finalOutputType =
726  buildConvOpResultTypeInfo(builder, outputType, input, weight);
727  }
728  result.addTypes(finalOutputType);
729 }
730 
731 /// The tosa.matmul op is also intended to be generated when a fully_connected
732 /// op must be constructed but the weight is not a constant. In this case,
733 /// the fully_connected op must be expressed using matmul.
734 /// TODO: Add link to the legalization document explaining this.
735 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
736                                        OperationState &result, Type outputType,
737  Value a, Value b) {
738  auto zps = createZPsAsConst(builder, a, b);
739  result.addOperands({a, b, zps.first, zps.second});
740 
741  Type finalOutputType{outputType};
742  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
743  auto eType = getStorageElementTypeOrSelf(a.getType());
744  auto inputBits = eType.getIntOrFloatBitWidth();
745 
746  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
747  assert(outputShapedType && "Output must be a shaped type");
748 
749  IntegerType accElementType;
750  if (inputBits == 16)
751  accElementType = builder.getIntegerType(48);
752  else
753  accElementType = builder.getI32Type();
754 
755  finalOutputType = outputShapedType.clone(accElementType);
756  }
757  result.addTypes(finalOutputType);
758 }
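// Note (illustrative): when a quantization attribute is built, i16 operands
// produce an i48 result element type, otherwise i32, mirroring the accumulator
// widths checked in verifyConvOpModes.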
759 
760 /// Both the tosa.avg_pool2d and unary ops use the same
761 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
762 /// because it has additional parameters not part of the unary ops.
763 static void
764 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
765                               Type outputType, Value input,
766  DenseArrayAttr kernel, DenseArrayAttr stride,
767  DenseArrayAttr pad, TypeAttr accType) {
768  const Location loc{result.location};
769  int64_t inputZp{0};
770  int64_t outputZp{0};
771 
772  if (auto quantAttr =
773  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
774  inputZp = quantAttr.getInputZp();
775  outputZp = quantAttr.getOutputZp();
776  }
777  const std::optional<Value> inputZpOp =
778  createZeroPointTensor(builder, loc, input.getType(), inputZp);
779  if (!inputZpOp) {
780  (void)emitError(
781  loc,
782  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
783  }
784  const std::optional<Value> outputZpOp =
785  createZeroPointTensor(builder, loc, outputType, outputZp);
786  if (!outputZpOp) {
787  (void)emitError(loc, "Failed to create output zero point tensor for "
788  "quantized AVG_POOL2D op");
789  }
790 
791  if (inputZpOp && outputZpOp) {
792  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
793  } else {
794  // Failed to create one or more zero points above: just add the input as
795  // operand. This will trigger an error when building the op because of the
796  // missing zero points.
797  result.addOperands({input});
798  }
799  result.addAttribute("kernel", kernel);
800  result.addAttribute("stride", stride);
801  result.addAttribute("pad", pad);
802  result.addAttribute("acc_type", accType);
803  result.types.push_back(outputType);
804 }
805 
806 /// This builder is called on single-parameter negate operator
807 /// to construct input and output zero points based on their
808 /// types.
809 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
810                                        OperationState &result, Type outputType,
811  Value input) {
812  const Location loc{result.location};
813  int64_t input1Zp{0};
814  int64_t outputZp{0};
815  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
816  if (quantAttr) {
817  input1Zp = quantAttr.getInputZp();
818  outputZp = quantAttr.getOutputZp();
819  }
820  const std::optional<Value> input1ZpOp =
821  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
822  if (!input1ZpOp) {
823  (void)emitError(
824  loc, "Failed to create input1 zero point for quantized NEGATE op");
825  }
826 
827  const std::optional<Value> outputZpOp =
828  createZeroPointTensor(builder, loc, input.getType(), outputZp);
829  if (!outputZpOp) {
830  (void)emitError(
831  loc, "Failed to create output zero point for quantized NEGATE op");
832  }
833 
834  if (input1ZpOp && outputZpOp) {
835  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
836  } else {
837  // failed to create one or more zero points above: just add input as
838  // operands. This will trigger error in building the op because of
839  // missing zero points
840  result.addOperands({input});
841  }
842 
843  result.types.push_back(outputType);
844 }
845 
846 /// This builder is called on TOSA pad operator that needs to create its own
847 /// OptionalAttr quantization_attr parameter to scale the padding values
848 /// correctly. No pad_const is interpreted as zero-padding.
849 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
850  Type outputType, Value input,
851  Value paddings) {
852  const Location loc{result.location};
853  int32_t zp{0};
854  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
855  if (quantAttr) {
856  zp = static_cast<int32_t>(quantAttr.getInputZp());
857  }
858  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
859  result.addOperands({input, paddings, padConstOp});
860  result.types.push_back(outputType);
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // TOSA Operator Return Type Inference.
865 //===----------------------------------------------------------------------===//
866 
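// Computes the broadcasted output shape across all ranked operands: shapes are
// right-aligned and a dimension of size 1 broadcasts against any other size.
// For example, [1, 3] and [2, 1, 1] resolve to [2, 1, 3]. Fails if an operand
// is unranked or two static dimensions conflict.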
867 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
868  SmallVector<int64_t> &outShape) {
869  int64_t outRank = 0;
870  for (int i = 0, e = operands.size(); i != e; ++i) {
871  auto shape = operands.getShape(i);
872  if (!shape.hasRank()) {
873  // TODO(jennik): Update function to have better case handling for
874  // invalid operands and for ranked tensors.
875  return failure();
876  }
877  outRank = std::max<int64_t>(outRank, shape.getRank());
878  }
879 
880  outShape.resize(outRank, 1);
881 
882  for (int i = 0, e = operands.size(); i != e; ++i) {
883  auto shape = operands.getShape(i);
884  auto rankDiff = outShape.size() - shape.getRank();
885 
886  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
887  auto dim1 = outShape[i + rankDiff];
888  auto dim2 = shape.getDimSize(i);
889  auto resolvedDim = dim1;
890 
891  if (dim1 == 1) {
892  resolvedDim = dim2;
893  } else if (dim2 == 1) {
894  resolvedDim = dim1;
895  } else if (dim1 != dim2) {
896  return failure();
897  }
898  outShape[i + rankDiff] = resolvedDim;
899  }
900  }
901 
902  return success();
903 }
904 
905 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
906  MLIRContext *context, ::std::optional<Location> location,
907  ArgMaxOp::Adaptor adaptor,
908  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
909  ShapeAdaptor inputShape(adaptor.getInput().getType());
910  IntegerAttr axis = adaptor.getProperties().axis;
911  int32_t axisVal = axis.getValue().getSExtValue();
912 
913  if (!inputShape.hasRank()) {
914  inferredReturnShapes.push_back(ShapedTypeComponents());
915  return success();
916  }
917 
918  SmallVector<int64_t> outShape;
919  outShape.reserve(inputShape.getRank() - 1);
920  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
921  if (i == axisVal)
922  continue;
923  outShape.push_back(inputShape.getDimSize(i));
924  }
925 
926  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
927  return success();
928 }
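// Illustrative example (assumed shapes): argmax over axis 1 of a
// tensor<2x5x7xf32> input infers an output of shape [2, 7], i.e. the axis
// dimension is removed.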
929 
930 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
931  MLIRContext *context, ::std::optional<Location> location,
932  RFFT2dOp::Adaptor adaptor,
933  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
934  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
935 
936  if (!inputShape.hasRank())
937  return failure();
938 
939  llvm::SmallVector<int64_t> outputShape;
940  outputShape.resize(3, ShapedType::kDynamic);
941  outputShape[0] = inputShape.getDimSize(0);
942  outputShape[1] = inputShape.getDimSize(1);
943  int64_t inWidth = inputShape.getDimSize(2);
944 
945  // Note that we can support this calculation symbolically
946  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
947  if (inWidth != ShapedType::kDynamic)
948  outputShape[2] = inWidth / 2 + 1;
949 
950  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
951  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
952 
953  return success();
954 }
955 
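// RFFT2d/FFT2d require power-of-two spatial dimensions; for dimSize > 0,
// (dimSize & (dimSize - 1)) == 0 holds exactly when dimSize is a power of two.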
956 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
957  const llvm::StringRef dimName) {
958  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
959  if (!isPowerOfTwo)
960  return op->emitOpError("expected ")
961  << dimName << " to be a power of two, got " << dimSize;
962 
963  return success();
964 }
965 
966 LogicalResult tosa::RFFT2dOp::verify() {
967  const auto outputTypes = getResultTypes();
968  if (failed(verifyCompatibleShapes(outputTypes)))
969  return emitOpError("expected output shapes to match, got ") << outputTypes;
970 
971  const auto inputType =
972  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
973  if (!inputType)
974  return success();
975 
976  const int64_t height = inputType.getDimSize(1);
977  if (!ShapedType::isDynamic(height) &&
978  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
979  return failure();
980 
981  const int64_t width = inputType.getDimSize(2);
982  if (!ShapedType::isDynamic(width) &&
983  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
984  return failure();
985 
986  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
987  if (!outputType)
988  return success();
989 
990  // Batch and height input/output dimensions should match
991  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
992  outputType.getShape().drop_back())))
993  return emitOpError("expected batch and height dimensions of input/output "
994  "to match, got input=")
995  << inputType << " output=" << outputType;
996 
997  // Output width dimension expected to be input_width / 2 + 1
998  const int64_t outputWidth = outputType.getDimSize(2);
999  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1000  (outputWidth != (width / 2) + 1))
1001  return emitOpError(
1002  "expected output width to be equal to input_width / 2 + 1, got ")
1003  << outputWidth;
1004 
1005  return success();
1006 }
1007 
1008 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1009  MLIRContext *context, ::std::optional<Location> location,
1010  FFT2dOp::Adaptor adaptor,
1011  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1012  inferredReturnShapes.push_back(
1013  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1014  inferredReturnShapes.push_back(
1015  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1016  return success();
1017 }
1018 
1019 LogicalResult tosa::FFT2dOp::verify() {
1020  const auto inputRealType =
1021  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1022  const auto inputImagType =
1023  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1024  if (!inputRealType || !inputImagType)
1025  return success();
1026 
1027  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1028  return ShapedType::isDynamic(a) ? a : b;
1029  };
1030 
1031  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1032  inputImagType.getDimSize(1));
1033  if (!ShapedType::isDynamic(height) &&
1034  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1035  return failure();
1036 
1037  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1038  inputImagType.getDimSize(2));
1039  if (!ShapedType::isDynamic(width) &&
1040  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1041  return failure();
1042 
1043  return success();
1044 }
1045 
1046 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1047  MLIRContext *context, ::std::optional<Location> location,
1048  ConcatOp::Adaptor adaptor,
1049  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1050  // Infer all dimension sizes by reducing based on inputs.
1051  const Properties &prop = adaptor.getProperties();
1052  int32_t axis = prop.axis.getValue().getSExtValue();
1053  llvm::SmallVector<int64_t> outputShape;
1054  bool hasRankedInput = false;
1055  for (auto operand : adaptor.getOperands()) {
1056  ShapeAdaptor operandShape(operand.getType());
1057  if (!operandShape.hasRank())
1058  continue;
1059 
1060  // Copy the Operand's rank.
1061  if (!hasRankedInput)
1062  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1063 
1064  // Copy shapes until the dim is non-dynamic.
1065  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1066  if (i == axis || operandShape.isDynamicDim(i))
1067  continue;
1068  if (outputShape[i] == ShapedType::kDynamic)
1069  outputShape[i] = operandShape.getDimSize(i);
1070  if (outputShape[i] != operandShape.getDimSize(i))
1071  return emitOptionalError(location,
1072  "Cannot concat tensors with different sizes"
1073  " on the non-axis dimension ",
1074  i);
1075  }
1076 
1077  hasRankedInput = true;
1078  }
1079  Type inputType =
1080  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1081  if (!hasRankedInput) {
1082  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1083  return success();
1084  }
1085 
1086  // Determine the dimension size along the concatenation axis.
1087  int64_t concatDimSize = 0;
1088  for (auto operand : adaptor.getOperands()) {
1089  ShapeAdaptor operandShape(operand.getType());
1090 
1091  // We need to know the length of the concatenation axis of all inputs to
1092  // determine the dimension size of the output shape.
1093  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1094  concatDimSize = ShapedType::kDynamic;
1095  break;
1096  }
1097 
1098  concatDimSize += operandShape.getDimSize(axis);
1099  }
1100 
1101  outputShape[axis] = concatDimSize;
1102 
1103  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1104  return success();
1105 }
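// Illustrative example (assumed shapes): concatenating tensor<2x?xf32> and
// tensor<2x3xf32> on axis 1 infers tensor<2x?xf32>; the axis dimension becomes
// dynamic as soon as any input is dynamic along the axis, while non-axis
// dimensions are merged toward static sizes.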
1106 
1107 LogicalResult tosa::ConcatOp::verify() {
1108  // check that each input has same element type as output
1109  auto outType = getOutput().getType();
1110  const Operation::operand_range inputList = getInput1();
1111 
1112  // Check there is at least one input
1113  if (inputList.empty())
1114  return emitOpError("expect at least one input");
1115 
1116  if (!llvm::all_of(inputList, [&](auto input) {
1117  return succeeded(verifySameElementTypes(
1118  *this, /* inType = */ input.getType(), outType));
1119  })) {
1120  return failure();
1121  }
1122 
1123  const int32_t axis = getAxis();
1124  ShapeAdaptor firstRankedInputShape = nullptr;
1125  for (const auto &input : inputList) {
1126  const Type inputType = input.getType();
1127  ShapeAdaptor currShape(inputType);
1128  if (currShape.hasRank()) {
1129  firstRankedInputShape = currShape;
1130  // Check axis is in expected range
1131  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1132  return emitOpError("expect axis to be within range 0 <= axis < "
1133  "rank(input1[firstRankedTensorIdx]), got ")
1134  << axis;
1135  break;
1136  }
1137  }
1138 
1139  const auto allOperandsHasRank = [](const Value input) {
1140  return ShapeAdaptor(input.getType()).hasRank();
1141  };
1142  if (llvm::all_of(inputList, allOperandsHasRank)) {
1143  const int64_t firstInputRank = firstRankedInputShape.getRank();
1144 
1145  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1146  const ShapeAdaptor inputShape(input.getType());
1147  const int64_t inputRank = inputShape.getRank();
1148  const size_t operandNum = index + 1;
1149 
1150  // Check that each operand has the same rank
1151  if (inputRank != firstInputRank)
1152  return emitOpError(
1153  "expect all operands to have the same rank, but got ")
1154  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1155  << operandNum;
1156 
1157  // Check non-axis dims match
1158  for (int i = 0; i < inputRank; i++) {
1159  const int64_t inputDim = inputShape.getDimSize(i);
1160  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1161  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1162  inputShape.isDynamicDim(i))
1163  continue;
1164  if (inputDim != firstInputDim)
1165  return emitOpError("expect all operand shapes to have the same sizes "
1166  "on non-axis dimensions, but got ")
1167  << inputDim << " vs " << firstInputDim << " at index " << i
1168  << " on operands 0 and " << operandNum;
1169  }
1170  }
1171  }
1172 
1173  return success();
1174 }
1175 
1176 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1177  MLIRContext *context, ::std::optional<Location> location,
1178  ValueShapeRange operands, DictionaryAttr attributes,
1179  OpaqueProperties properties, RegionRange regions,
1180  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1181  auto elementType = IntegerType::get(context, /*width=*/1);
1182 
1183  llvm::SmallVector<int64_t> outShape;
1184  if (resolveBroadcastShape(operands, outShape).failed()) {
1185  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1186  return success();
1187  }
1188 
1189  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1190  return success();
1191 }
1192 
1193 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1194  if (l.size() != r.size() || l.size() != 1)
1195  return false;
1196  return succeeded(verifyCompatibleShape(l[0], r[0]));
1197 }
1198 
1199 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1200  MLIRContext *context, ::std::optional<Location> location,
1201  MatMulOp::Adaptor adaptor,
1202  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1203  ShapeAdaptor lhsShape(adaptor.getA().getType());
1204  ShapeAdaptor rhsShape(adaptor.getB().getType());
1205 
1206  // All shapes are dynamic.
1207  SmallVector<int64_t> outShape;
1208  outShape.resize(3, ShapedType::kDynamic);
1209 
1210  if (lhsShape.hasRank()) {
1211  outShape[0] = lhsShape.getDimSize(0);
1212  outShape[1] = lhsShape.getDimSize(1);
1213  }
1214 
1215  if (rhsShape.hasRank()) {
1216  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1217  : outShape[0];
1218  outShape[2] = rhsShape.getDimSize(2);
1219  }
1220 
1221  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1222  return success();
1223 }
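// Illustrative example (assumed shapes): tosa.matmul of tensor<1x14x19xf32>
// and tensor<1x19x28xf32> infers tensor<1x14x28xf32>; a dynamic batch
// dimension on A falls back to B's batch dimension when that one is static.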
1224 
1225 LogicalResult MatMulOp::verify() {
1226  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1227  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1228 
1229  // Must be shaped tensor types
1230  if (!aType)
1231  return emitOpError("expect a shaped tensor for input a, got ")
1232  << getA().getType();
1233 
1234  if (!bType)
1235  return emitOpError("expect a shaped tensor for input b, got ")
1236  << getB().getType();
1237 
1238  auto aElementType = aType.getElementType();
1239  auto bElementType = bType.getElementType();
1240 
1241  auto aQuantizedEType =
1242  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1243  auto bQuantizedEType =
1244  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1245 
1246  if (aQuantizedEType || bQuantizedEType) {
1247  if (!aQuantizedEType || !bQuantizedEType) {
1248  return emitOpError("expect operands to be both quantized or both not "
1249  "quantized, got ")
1250  << aElementType << " and " << bElementType;
1251  }
1252  // both a and b have quantized element types
1253  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1254  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1255  if (aQuantWidth != bQuantWidth) {
1256  return emitOpError("expect quantized operands to have same widths, got ")
1257  << aQuantWidth << " and " << bQuantWidth;
1258  }
1259  } else {
1260  // non-quantized element types
1261  if (aElementType != bElementType) {
1262  return emitOpError("expect same element type for inputs a and b, got ")
1263  << aElementType << " and " << bElementType;
1264  }
1265  }
1266 
1267  // check a_zp and b_zp
1268  auto aEType = getStorageElementTypeOrSelf(aType);
1269  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1270  if (aEType != aZpEType) {
1271  return emitOpError("expect input a and a_zp have the same "
1272  "element type, got ")
1273  << aEType << " and " << aZpEType;
1274  }
1275 
1276  auto bEType = getStorageElementTypeOrSelf(bType);
1277  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1278  if (bEType != bZpEType) {
1279  return emitOpError("expect input b and b_zp have the same "
1280  "element type, got ")
1281  << bEType << " and " << bZpEType;
1282  }
1283 
1284  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1285  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1286  return failure();
1287 
1288  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1289  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1290  return failure();
1291 
1292  return success();
1293 }
1294 
1295 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1296  MLIRContext *context, ::std::optional<Location> location,
1297  PadOp::Adaptor adaptor,
1298  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1299  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1300  auto paddingRank =
1301  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1302  SmallVector<int64_t> outputShape;
1303 
1304  // If the input rank is unknown, we can infer the output rank using the
1305  // padding shape's rank divided by 2.
1306  if (!inputShape.hasRank()) {
1307  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1308  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1309  return success();
1310  }
1311 
1312  SmallVector<int64_t> paddingValues;
1313  // If the paddings value is not a constant, all dimensions must be dynamic.
1314  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1315  paddingValues)) {
1316  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1317  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1318  return success();
1319  }
1320 
1321  outputShape.reserve(inputShape.getRank());
1322  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1323  if (inputShape.isDynamicDim(i)) {
1324  outputShape.push_back(ShapedType::kDynamic);
1325  continue;
1326  }
1327  auto padFront = paddingValues[i * 2];
1328  auto padBack = paddingValues[i * 2 + 1];
1329  if (padFront < 0 || padBack < 0) {
1330  // if either padding for dim i is -1, output dim is unknown
1331  outputShape.push_back(ShapedType::kDynamic);
1332  continue;
1333  }
1334 
1335  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1336  }
1337 
1338  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1339  return success();
1340 }
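// Illustrative example (assumed values): padding a tensor<2x3xf32> with
// constant padding values [1, 1, 0, 2] infers tensor<4x5xf32>; negative
// (unknown) padding entries make the corresponding output dimension dynamic.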
1341 
1342 LogicalResult tosa::PadOp::verify() {
1343  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1344  /* outType = */ getOutput().getType())
1345  .failed()) {
1346  return failure();
1347  }
1348 
1349  if (auto padConst = getPadConst()) {
1350  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1351  /* outType = */ getOutput().getType())
1352  .failed()) {
1353  return failure();
1354  }
1355  }
1356 
1357  RankedTensorType inputType =
1358  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1359  RankedTensorType outputType =
1360  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1361  if (!inputType || !outputType)
1362  return success();
1363 
1364  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
1365 
1366  if (inputType.getRank() != outputType.getRank())
1367  return emitOpError() << "expect same input and output tensor rank.";
1368 
1369  if (paddingRank != inputType.getRank() * 2)
1370  return emitOpError() << "expected padding tensor dim 0 to have size "
1371  << inputType.getRank() * 2
1372  << " (2*rank(shape1)) but got size " << paddingRank;
1373 
1374  return success();
1375 }
1376 
1377 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
1378  return to_vector(llvm::map_range(shape, [](int64_t dim) {
1379  return dim == -1 ? ShapedType::kDynamic : dim;
1380  }));
1381 }
1382 
1383 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1384  MLIRContext *context, ::std::optional<Location> location,
1385  SliceOp::Adaptor adaptor,
1386  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1387 
1388  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1389  SmallVector<int64_t> start;
1390  SmallVector<int64_t> size;
1391 
1392  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1393  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1394  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1395  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1396  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1397  return success();
1398  }
1399 
1400  // if size[i] is -1, all remaining elements in dimension i are included
1401  // in the slice, similar to TF.
1402  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1403  // initialize outputShape to all unknown
1404  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1405  if (inputShape.hasRank()) {
1406  for (size_t i = 0; i < size.size(); i++) {
1407  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1408  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1409  start[i] < inputShape.getDimSize(i))) {
1410  // size[i] is not 0 and not < -1, and start[i] is in valid range
1411  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1412  // input shape has unknown dim[i] - only valid if size[i] > 0
1413  if (size[i] > 0) {
1414  outputShape[i] = size[i];
1415  }
1416  } else {
1417  // input shape has known dim[i]
1418  if (size[i] == -1) {
1419  outputShape[i] = inputShape.getDimSize(i) - start[i];
1420  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1421  // start[i] + size[i] is within bound of input shape's dim[i]
1422  outputShape[i] = size[i];
1423  }
1424  }
1425  }
1426  }
1427  } else {
1428  outputShape = convertToMlirShape(size);
1429  }
1430  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1431  return success();
1432 }
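// Illustrative example (assumed values): slicing a tensor<4x8xf32> with
// start = [1, 2] and size = [2, -1] infers tensor<2x6xf32>, since a size of
// -1 extends the slice to the end of that dimension.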
1433 
1434 LogicalResult tosa::SliceOp::verify() {
1435  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1436  /* outType = */ getOutput().getType())
1437  .failed())
1438  return failure();
1439  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1440  if (!inputType)
1441  return success();
1442 
1443  auto startShapeRank =
1444  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1445  if (inputType.getRank() != startShapeRank)
1446  return emitOpError("length of start is not equal to rank of input shape");
1447 
1448  auto sizeShapeRank =
1449  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1450  if (inputType.getRank() != sizeShapeRank)
1451  return emitOpError("length of size is not equal to rank of input shape");
1452 
1453  return success();
1454 }
1455 
1456 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1457  MLIRContext *context, ::std::optional<Location> location,
1458  ValueShapeRange operands, DictionaryAttr attributes,
1459  OpaqueProperties properties, RegionRange regions,
1460  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1461  // The mul op's output shape depends only on input1 and input2, not on shift.
1462  ValueShapeRange twoInputs = operands.drop_back();
1463  llvm::SmallVector<int64_t> outShape;
1464  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1465  inferredReturnShapes.push_back(ShapedTypeComponents());
1466  } else {
1467  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1468  }
1469  return success();
1470 }
1471 
1472 LogicalResult tosa::MulOp::verify() {
1473  auto resElemType = getElementTypeOrSelf(getOutput());
1474 
1475  // Verify if the element type among operands and result match tosa
1476  // specification.
1477  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1478  IntegerType lhsIntType =
1479  cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1480  IntegerType rhsIntType =
1481  cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1482  if (lhsIntType != rhsIntType)
1483  return emitOpError("requires the same element type for all operands");
1484 
1485  // Though the spec requires the element type of the result to be i32, a more
1486  // relaxed constraint is allowed at the dialect level to make it easier to
1487  // cooperate with other dialects.
1488  if (lhsIntType.getWidth() > resIntType.getWidth())
1489  return emitOpError("invalid data type size for operands or result");
1490 
1491  } else {
1492  // For other supported types, the spec requires the same element type for
1493  // all operands (excluding the `shift` operand) and results.
1494  for (int i = 0; i < 2; ++i) {
1495  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1496  return emitOpError(
1497  "requires the same element type for all operands and results");
1498  }
1499 
1500  // verify shift has value 0 for non-integer types
1501  ElementsAttr shiftElem;
1502  if (matchPattern(getShift(), m_Constant(&shiftElem))) {
1503  int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
1504  if (shift != 0) {
1505  return emitOpError() << "require shift to be 0 for float type";
1506  }
1507  }
1508  }
1509 
1510  // Verify that the op has the same rank for all main operands and result
1511  // types, if known. Extra operands such as mul's `shift` are excluded; this
1512  // is the only difference from the built-in `SameOperandsAndResultRank` trait.
1513 
1514  // Helper that returns true if the type is a shaped type with a known
1515  // rank.
1516  auto hasRank = [](const Type type) {
1517  if (auto shapedType = dyn_cast<ShapedType>(type))
1518  return shapedType.hasRank();
1519 
1520  return false;
1521  };
1522 
1523  auto rankedOperandTypes =
1524  llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1525 
1526  auto rankedResultTypes =
1527  llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1528 
1529  // If all operands and results are unranked, then no further verification.
1530  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1531  return success();
1532 
1533  // Helper that returns the rank of a shaped type with a known rank.
1534  auto getRank = [](const Type type) {
1535  return cast<ShapedType>(type).getRank();
1536  };
1537 
1538  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1539  : getRank(*rankedResultTypes.begin());
1540 
1541  for (size_t i = 0; i < 2; ++i) {
1542  if (rank != getRank(rankedOperandTypes[i])) {
1543  return emitOpError("operands don't have matching ranks");
1544  }
1545  }
1546 
1547  for (const auto type : rankedResultTypes) {
1548  if (rank != getRank(type)) {
1549  return emitOpError("result type has different rank than operands");
1550  }
1551  }
1552 
1553  // check for broadcast compatible shapes in first two operands (ignoring
1554  // shift)
1555 
1556  // Helper that returns the shape of a shaped type.
1557  auto getShape = [](const Type type) {
1558  return mlir::cast<ShapedType>(type).getShape();
1559  };
1560  SmallVector<int64_t> resultShape;
1561  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1562  getShape(rankedOperandTypes[1]),
1563  resultShape)) {
1564  return emitOpError("operands don't have broadcast-compatible shapes");
1565  }
1566 
1567  return success();
1568 }
1569 
1570 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1571  MLIRContext *context, ::std::optional<Location> location,
1572  TableOp::Adaptor adaptor,
1573  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1574  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1575 
1576  if (!inputShape.hasRank()) {
1577  inferredReturnShapes.push_back(ShapedTypeComponents());
1578  return success();
1579  }
1580 
1581  inferredReturnShapes.resize(1);
1582  inputShape.getDims(inferredReturnShapes[0]);
1583  return success();
1584 }
1585 
1586 LogicalResult tosa::TableOp::verify() {
1587  TensorType inputType = getInput1().getType();
1588  TensorType outputType = getOutput().getType();
1589 
1590  if (inputType.hasRank() && outputType.hasRank() &&
1591  inputType.getRank() != outputType.getRank())
1592  return emitOpError()
1593  << "expected input tensor rank to equal result tensor rank";
1594 
1595  auto inputDims = inputType.getShape();
1596  auto outputDims = outputType.getShape();
1597  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1598  int64_t dim = it.index();
1599  auto [inputDim, outputDim] = it.value();
1600  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1601  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1602  << " doesn't match dim(input, " << dim
1603  << ") = " << inputDim;
1604  }
1605  }
1606  return success();
1607 }
1608 
1609 LogicalResult
1610 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1611  // Multiples must be constants.
1612  DenseIntElementsAttr multiplesAttr;
1613  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1614  return failure();
1615  multiples = llvm::to_vector(
1616  llvm::map_range(multiplesAttr.getValues<APInt>(),
1617  [](const APInt &val) { return val.getSExtValue(); }));
1618  return success();
1619 }
1620 
1621 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1622  MLIRContext *context, ::std::optional<Location> location,
1623  TileOp::Adaptor adaptor,
1624  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1625  DenseIntElementsAttr multiplesAttr;
1626  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1627  return failure();
1628 
1629  SmallVector<int64_t> multiples = llvm::to_vector(
1630  llvm::map_range(multiplesAttr.getValues<APInt>(),
1631  [](const APInt &val) { return val.getSExtValue(); }));
1632 
1633  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1634  SmallVector<int64_t> outputShape;
1635  if (!inputShape.hasRank()) {
1636  outputShape.resize(multiples.size(), ShapedType::kDynamic);
1637  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1638  return success();
1639  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1640  return failure();
1641 
1642  // Any non-dynamic dimension can be multiplied to a known size.
1643  outputShape.reserve(multiples.size());
1644  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1645  int64_t dim = inputShape.getDimSize(i);
1646  if (dim != ShapedType::kDynamic)
1647  dim *= multiples[i];
1648  outputShape.push_back(dim);
1649  }
1650 
1651  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1652  return success();
1653 }
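// For example (illustrative values only): tiling an input of type
// tensor<2x?x3xf32> with constant multiples [3, 2, 1] yields
//   dim 0: 2 * 3 = 6,  dim 1: unknown stays unknown,  dim 2: 3 * 1 = 3,
// i.e. an inferred output shape of [6, ?, 3].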
1654 
1655 LogicalResult tosa::TileOp::verify() {
1656  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1657  /* outType = */ getOutput().getType())
1658  .failed()) {
1659  return failure();
1660  }
1661  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1662  ShapedType outputType = llvm::cast<ShapedType>(getType());
1663 
1664  shapeType multiplesType =
1665  llvm::cast<tosa::shapeType>(getMultiples().getType());
1666 
1667  auto multiplesRank = multiplesType.getRank();
1668 
1669  if (inputType.hasRank()) {
1670  if (inputType.getRank() != multiplesRank)
1671  return emitOpError("expect 'multiples' to have rank ")
1672  << inputType.getRank() << " but got " << multiplesRank << ".";
1673  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1674  return emitOpError("expect same input and output tensor rank.");
1675  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1676  return emitOpError("expect 'multiples' array to have length ")
1677  << outputType.getRank() << " but got " << multiplesRank << ".";
1678 
1679  SmallVector<int64_t> multiples;
1680  if (getConstantMultiples(multiples).succeeded() &&
1681  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1682  return emitOpError(
1683  "expect element of 'multiples' to be positive integer or -1.");
1684 
1685  return success();
1686 }
1687 
1688 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1689  if (l.size() != r.size() || l.size() != 1)
1690  return false;
1691  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
1692 }
1693 
1694 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1695  MLIRContext *context, ::std::optional<Location> location,
1696  ReshapeOp::Adaptor adaptor,
1697  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1698  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1699  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1700  llvm::SmallVector<int64_t> newShapeValue;
1701  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
1702  newShapeValue)) {
1703  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1704  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1705  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1706  return success();
1707  } else {
1708  newShapeValue = convertToMlirShape(newShapeValue);
1709  }
1710 
1711  // We cannot infer from the total number of elements so we must take the
1712  // shape attribute as exact.
1713  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1714  inferredReturnShapes.push_back(
1715  ShapedTypeComponents(newShapeValue, inputType));
1716  return success();
1717  }
1718 
1719  // Determine the number of elements covered by the slice of all static
1720  // dimensions. This allows us to infer the length of the remaining dynamic
1721  // dimension.
1722  int64_t numElements = inputShape.getNumElements();
1723  int64_t staticMul = 1;
1724  for (auto val : newShapeValue) {
1725  if (!ShapedType::isDynamic(val)) {
1726  staticMul *= val;
1727  }
1728  }
1729 
1730  // Determine the length of the dynamic dimension.
1731  for (auto &val : newShapeValue) {
1732  if (ShapedType::isDynamic(val))
1733  val = numElements / staticMul;
1734  }
1735 
1736  inferredReturnShapes.push_back(
1737  ShapedTypeComponents(newShapeValue, inputType));
1738  return success();
1739 }
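// For example (illustrative values only): reshaping a static
// tensor<2x3x4xf32> (24 elements) with a constant shape of [4, -1] gives
// staticMul = 4, so the dynamic dimension is inferred as 24 / 4 = 6 and the
// result shape becomes [4, 6].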
1740 
1741 llvm::LogicalResult tosa::ReshapeOp::verify() {
1742  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1743  /* outType = */ getOutput().getType())
1744  .failed()) {
1745  return failure();
1746  }
1747  TensorType inputType = getInput1().getType();
1748  RankedTensorType outputType = getType();
1749 
1750  SmallVector<int64_t> shapeValues;
1751  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
1752  // skip following checks if shape is not constant
1753  return mlir::success();
1754  }
1755 
1756  if ((int64_t)shapeValues.size() != outputType.getRank())
1757  return emitOpError() << "new shape does not match result rank";
1758 
1759  for (auto [newShapeDim, outputShapeDim] :
1760  zip(shapeValues, outputType.getShape())) {
1761  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1762  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
1763  return emitOpError() << "new shape is inconsistent with result shape";
1764 
1765  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1766  return emitOpError() << "new shape has invalid tensor dimension size "
1767  << newShapeDim;
1768  }
1769 
1770  if (inputType.hasStaticShape()) {
1771  int64_t inputElementsNum = inputType.getNumElements();
1772  if (outputType.hasStaticShape()) {
1773  int64_t outputElementsNum = outputType.getNumElements();
1774  if (inputElementsNum != outputElementsNum) {
1775  return emitOpError() << "cannot reshape " << inputElementsNum
1776  << " elements into " << outputElementsNum;
1777  }
1778  }
1779 
1780  int64_t newShapeElementsNum = std::accumulate(
1781  shapeValues.begin(), shapeValues.end(), 1LL,
1782  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1783  bool isStaticNewShape =
1784  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
1785  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1786  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1787  return emitOpError() << "cannot reshape " << inputElementsNum
1788  << " elements into " << newShapeElementsNum;
1789  }
1790  }
1791 
1792  int missingDims = llvm::count(shapeValues, -1);
1793  if (missingDims > 1)
1794  return emitOpError() << "expected at most one target dimension to be -1";
1795 
1796  return mlir::success();
1797 }
1798 
1799 // Return failure if val is not a constant.
1800 // Set zp to -1 if val is a non-zero float or is neither an integer nor a
1801 // float; otherwise set zp to val's constant value.
1802 template <typename T>
1803 static FailureOr<int64_t> getZeroPoint(T op, Value val) {
1804  ElementsAttr zpAttr;
1805  if (!matchPattern(val, m_Constant(&zpAttr))) {
1806  return failure();
1807  }
1808 
1809  Type zpElemType = zpAttr.getElementType();
1810 
1811  if (llvm::isa<FloatType>(zpElemType)) {
1812  if (zpAttr.getValues<APFloat>()[0].isZero()) {
1813  return 0;
1814  }
1815  // return non-zero value to trigger error check
1816  return -1;
1817  }
1818 
1819  if (llvm::isa<IntegerType>(zpElemType)) {
1820  return zpAttr.getValues<APInt>()[0].getSExtValue();
1821  }
1822 
1823  // return non-zero value to trigger error check
1824  return -1;
1825 }
1826 
1827 template <typename T>
1828 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
1829  const std::string &operand) {
1830  Type zpElemType = getElementTypeOrSelf(val);
1831 
1832  if (!zpElemType.isInteger(8) && zp != 0) {
1833  // convert operand to lower case for error message
1834  std::string lower = operand;
1835  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
1836  return op.emitOpError()
1837  << lower << " zero point must be zero for non-int8 integer types";
1838  }
1839 
1840  return success();
1841 }
1842 
1843 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
1844  const int64_t &zp,
1845  const std::string &operand) {
1846  bool isInputZp = (operand == "Input");
1847 
1848  bool tensorUnsigned =
1849  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
1850  StringRef tensorName = isInputZp ? "input" : "output";
1851 
1852  Type zpElemType = getElementTypeOrSelf(zpVal);
1853 
1854  if (zp != 0) {
1855  if (!zpElemType.isInteger(8) &&
1856  !(zpElemType.isInteger(16) && tensorUnsigned)) {
1857  return op.emitOpError()
1858  << "expect " << tensorName << "_zp of 0, got " << zp;
1859  }
1860  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
1861  return op.emitOpError() << "expect " << tensorName
1862  << "_zp of 0 or 32768 for unsigned int16 "
1863  << tensorName << ", got " << zp;
1864  }
1865  }
1866 
1867  return success();
1868 }
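// Taken together (assuming typical usage): getZeroPoint extracts the constant
// zero-point value, e.g. 0 from a splat float constant or -128 from an i8
// constant, and the verifyZeroPoint overloads then reject values the element
// type does not permit, e.g. a non-zero zero point on anything other than
// int8 (or, in the rescale case, unsigned int16, where only 0 or 32768 is
// accepted).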
1869 
1870 #define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
1871  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
1872  return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
1873  } \
1874  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
1875  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
1876  }
1877 
1878 ZERO_POINT_HELPER(Conv2DOp, Input)
1879 ZERO_POINT_HELPER(Conv2DOp, Weight)
1880 ZERO_POINT_HELPER(Conv3DOp, Input)
1881 ZERO_POINT_HELPER(Conv3DOp, Weight)
1882 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
1883 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
1884 ZERO_POINT_HELPER(TransposeConv2DOp, Input)
1885 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
1886 ZERO_POINT_HELPER(AvgPool2dOp, Input)
1887 ZERO_POINT_HELPER(AvgPool2dOp, Output)
1888 ZERO_POINT_HELPER(MatMulOp, A)
1889 ZERO_POINT_HELPER(MatMulOp, B)
1890 ZERO_POINT_HELPER(NegateOp, Input1)
1891 ZERO_POINT_HELPER(NegateOp, Output)
1892 ZERO_POINT_HELPER(RescaleOp, Input)
1893 ZERO_POINT_HELPER(RescaleOp, Output)
1894 #undef ZERO_POINT_HELPER
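// For reference, each ZERO_POINT_HELPER(OP, OPERAND_NAME) instantiation above
// expands to two member functions; e.g. ZERO_POINT_HELPER(Conv2DOp, Input)
// expands (roughly) to:
//
//   FailureOr<int64_t> tosa::Conv2DOp::getInputZeroPoint() {
//     return getZeroPoint(*this, getInputZp());
//   }
//   LogicalResult tosa::Conv2DOp::verifyInputZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getInputZp(), zp, "Input");
//   }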
1895 
1896 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1897  MLIRContext *context, ::std::optional<Location> location,
1898  TransposeOp::Adaptor adaptor,
1899  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1900  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1901 
1902  // If the input rank is unknown, the output rank cannot be inferred and is
1903  // left unknown.
1904  if (!inputShape.hasRank()) {
1905  inferredReturnShapes.push_back(ShapedTypeComponents());
1906  return success();
1907  }
1908 
1909  const auto inputRank = inputShape.getRank();
1910 
1911  // This would imply the number of permutations does not match the rank of
1912  // the input, which is illegal.
1913  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
1914  return failure();
1915  }
1916 
1917  SmallVector<int64_t> outputShape;
1918  // Rank-0 means no permutations matter.
1919  if (inputRank == 0) {
1920  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1921  return success();
1922  }
1923 
1924  // Check whether the input dimensions are all the same.
1925  bool allTheSame = true;
1926  for (int i = 1, s = inputRank; i < s; i++) {
1927  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1928  allTheSame = false;
1929  break;
1930  }
1931  }
1932 
1933  // If all of the input dimensions are the same we don't care about the
1934  // permutation.
1935  if (allTheSame) {
1936  outputShape.resize(inputRank, inputShape.getDimSize(0));
1937  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1938  return success();
1939  }
1940 
1941  outputShape.resize(inputRank, ShapedType::kDynamic);
1942 
1943  // Constant permutation values must be within the input rank.
1944  if (llvm::any_of(adaptor.getPerms(),
1945  [inputRank](const auto i) { return i >= inputRank; }))
1946  return failure();
1947 
1948  outputShape.reserve(inputRank);
1949  for (int i = 0, s = inputRank; i < s; i++) {
1950  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
1951  }
1952 
1953  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1954  return success();
1955 }
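// For example (illustrative values only): transposing an input of type
// tensor<1x2x3xf32> with perms = [2, 0, 1] picks input dims [3, 1, 2] as the
// inferred output shape, since output dim i takes the size of input dim
// perms[i].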
1956 
1957 LogicalResult tosa::TransposeOp::verify() {
1958  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1959  /* outType = */ getOutput().getType())
1960  .failed()) {
1961  return failure();
1962  }
1963  TensorType inputType = getInput1().getType();
1964  TensorType outputType = getOutput().getType();
1965  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
1966 
1967  if (inputType.hasRank() &&
1968  constantPerms.size() != static_cast<size_t>(inputType.getRank()))
1969  return emitOpError() << "expected perms attribute to have size "
1970  << inputType.getRank() << " (input rank) but got size "
1971  << constantPerms.size();
1972  if (inputType.hasRank() && outputType.hasRank() &&
1973  inputType.getRank() != outputType.getRank())
1974  return emitOpError()
1975  << "expected input tensor rank to equal result tensor rank";
1976  if (outputType.hasRank() &&
1977  constantPerms.size() != static_cast<size_t>(outputType.getRank()))
1978  return emitOpError() << "expected perms attribute to have size "
1979  << outputType.getRank()
1980  << " (output rank) but got size "
1981  << constantPerms.size();
1982 
1983  if (!llvm::all_of(constantPerms,
1984  [&constantPerms](int32_t s) {
1985  return s >= 0 &&
1986  static_cast<size_t>(s) < constantPerms.size();
1987  }) ||
1988  !isPermutationVector(llvm::to_vector(llvm::map_range(
1989  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1990  return emitOpError() << "expected valid permutation indices";
1991 
1992  // Verify that the types of the input and output tensors are properly
1993  // permuted.
1994  if (inputType.hasRank() && outputType.hasRank()) {
1995  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1996  inputType.getRank() == outputType.getRank());
1997 
1998  for (auto i = 0; i < outputType.getRank(); i++) {
1999  if (inputType.isDynamicDim(constantPerms[i]) ||
2000  outputType.isDynamicDim(i))
2001  continue;
2002 
2003  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
2004  return emitOpError()
2005  << "expected output tensor dim " << i << " to match "
2006  << "input dim " << constantPerms[i] << " with value of "
2007  << inputType.getDimSize(constantPerms[i]);
2008  }
2009  }
2010 
2011  return success();
2012 }
2013 
2014 LogicalResult TransposeOp::reifyResultShapes(
2015  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2016 
2017  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2018 
2019  Value input = getInput1();
2020  auto inputType = cast<TensorType>(input.getType());
2021 
2022  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2023  for (auto dim : transposePerms) {
2024  int32_t dimInInput = transposePerms[dim];
2025  if (inputType.isDynamicDim(dimInInput))
2026  returnedDims[dim] =
2027  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2028  .getResult();
2029  else
2030  returnedDims[dim] =
2031  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2032  }
2033 
2034  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2035  return success();
2036 }
2037 
2038 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2039  MLIRContext *context, ::std::optional<Location> location,
2040  GatherOp::Adaptor adaptor,
2041  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2042  llvm::SmallVector<int64_t> outputShape;
2043  outputShape.resize(3, ShapedType::kDynamic);
2044 
2045  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2046  if (valuesShape.hasRank()) {
2047  outputShape[0] = valuesShape.getDimSize(0);
2048  outputShape[2] = valuesShape.getDimSize(2);
2049  }
2050 
2051  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2052  if (indicesShape.hasRank()) {
2053  if (outputShape[0] == ShapedType::kDynamic)
2054  outputShape[0] = indicesShape.getDimSize(0);
2055  if (outputShape[1] == ShapedType::kDynamic)
2056  outputShape[1] = indicesShape.getDimSize(1);
2057  }
2058 
2059  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2060  return success();
2061 }
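// For example (illustrative values only): with values of type
// tensor<2x?x5xf32> and indices of type tensor<2x3xi32>, the output shape is
// inferred as [2, 3, 5]: N and C come from the values operand and W comes
// from the indices operand.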
2062 
2063 LogicalResult tosa::GatherOp::verify() {
2064  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2065  /* outType = */ getOutput().getType());
2066 }
2067 
2068 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2069  MLIRContext *context, ::std::optional<Location> location,
2070  ResizeOp::Adaptor adaptor,
2071  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2072  llvm::SmallVector<int64_t, 4> outputShape;
2073  outputShape.resize(4, ShapedType::kDynamic);
2074 
2075  ShapeAdaptor inputShape(adaptor.getInput().getType());
2076  if (!inputShape.hasRank())
2077  return failure();
2078 
2079  outputShape[0] = inputShape.getDimSize(0);
2080  outputShape[3] = inputShape.getDimSize(3);
2081  int64_t inputHeight = inputShape.getDimSize(1);
2082  int64_t inputWidth = inputShape.getDimSize(2);
2083 
2084  if ((inputHeight == ShapedType::kDynamic) ||
2085  (inputWidth == ShapedType::kDynamic))
2086  return failure();
2087 
2088  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2089  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2090  scaleInt) ||
2091  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2092  offsetInt) ||
2093  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2094  borderInt)) {
2095  return failure();
2096  }
2097 
2098  // Compute the output shape based on attributes: scale, offset, and border.
2099  outputShape[1] =
2100  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2101  scaleInt[1]) +
2102  1;
2103 
2104  outputShape[2] =
2105  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2106  scaleInt[3]) +
2107  1;
2108 
2109  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2110  return success();
2111 }
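// For example (illustrative values only): with input height 8, scale
// [2, 1, 2, 1], offset [0, 0] and border [0, 0], the inferred output height
// is ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15, and the output width follows the
// same formula using the x components.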
2112 
2113 LogicalResult tosa::ResizeOp::verify() {
2114  const Value input = getInput();
2115  const Value output = getOutput();
2116  const RankedTensorType inputType =
2117  llvm::dyn_cast<RankedTensorType>(input.getType());
2118  const RankedTensorType outputType =
2119  llvm::dyn_cast<RankedTensorType>(output.getType());
2120 
2121  if (!inputType)
2122  return emitOpError("expect a ranked input tensor");
2123  if (!outputType)
2124  return emitOpError("expect a ranked output tensor");
2125 
2126  const int64_t oh = outputType.getDimSize(1);
2127  const int64_t ow = outputType.getDimSize(2);
2128  const int64_t ih = inputType.getDimSize(1);
2129  const int64_t iw = inputType.getDimSize(2);
2130 
2131  SmallVector<int64_t> scaleValues;
2132  SmallVector<int64_t> offsetValues;
2133  SmallVector<int64_t> borderValues;
2134  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2135  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2136  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2137  // Skip following checks if shape is not constant
2138  return success();
2139  }
2140 
2141  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2142  return emitOpError("expect all scale values to be > 0, got ")
2143  << scaleValues;
2144 
2145  const int64_t scaleYN = scaleValues[0];
2146  const int64_t scaleYD = scaleValues[1];
2147  const int64_t scaleXN = scaleValues[2];
2148  const int64_t scaleXD = scaleValues[3];
2149 
2150  const int64_t offsetY = offsetValues[0];
2151  const int64_t offsetX = offsetValues[1];
2152 
2153  const int64_t borderY = borderValues[0];
2154  const int64_t borderX = borderValues[1];
2155 
2156  // Don't check with input height that could be broadcast (ih != 1)
2157  // since Linalg, a consumer of TOSA, expects broadcasting support
2158  // in resize to be available. Taking the cautious approach for now,
2159  // we can consider removing support for broadcasting later.
2160  if (ih != ShapedType::kDynamic && ih != 1) {
2161  const std::optional<int64_t> calculatedOutHeightMinusOne =
2162  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2163  if (!calculatedOutHeightMinusOne.has_value())
2164  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2165  "border_y ")
2166  << "to be wholly divisible by scale_y_d, got ((" << ih
2167  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2168  << ") / " << scaleYD;
2169  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2170  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2171  return emitOpError("calculated output height did not match expected: ")
2172  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2173  }
2174 
2175  // Don't check with input width that could be broadcast (iw != 1)
2176  // since Linalg, a consumer of TOSA, expects broadcasting support
2177  // in resize to be available. Taking the cautious approach for now,
2178  // we can consider removing support for broadcasting later.
2179  if (iw != ShapedType::kDynamic && iw != 1) {
2180  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2181  const std::optional<int64_t> calculatedOutWidthMinusOne =
2182  idivCheck(scaledInWidth, scaleXD);
2183  if (!calculatedOutWidthMinusOne.has_value())
2184  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2185  "border_x ")
2186  << "to be wholly divisible by scale_x_d, got ((" << iw
2187  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2188  << ") / " << scaleXD;
2189  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2190  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2191  return emitOpError("calculated output width did not match expected: ")
2192  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2193  }
2194 
2195  return success();
2196 }
2197 
2198 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2199  MLIRContext *context, ::std::optional<Location> location,
2200  ScatterOp::Adaptor adaptor,
2201  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2202  llvm::SmallVector<int64_t> outputShape;
2203  outputShape.resize(3, ShapedType::kDynamic);
2204 
2205  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2206  if (valuesInShape.hasRank()) {
2207  outputShape[0] = valuesInShape.getDimSize(0);
2208  outputShape[1] = valuesInShape.getDimSize(1);
2209  outputShape[2] = valuesInShape.getDimSize(2);
2210  }
2211 
2212  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2213  if (indicesShape.hasRank()) {
2214  if (outputShape[0] == ShapedType::kDynamic)
2215  outputShape[0] = indicesShape.getDimSize(0);
2216  }
2217 
2218  ShapeAdaptor inputShape(adaptor.getInput().getType());
2219  if (inputShape.hasRank()) {
2220  if (outputShape[0] == ShapedType::kDynamic)
2221  outputShape[0] = inputShape.getDimSize(0);
2222  if (outputShape[2] == ShapedType::kDynamic)
2223  outputShape[2] = inputShape.getDimSize(2);
2224  }
2225 
2226  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2227  return success();
2228 }
2229 
2230 LogicalResult tosa::ScatterOp::verify() {
2231  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2232  /* outType = */ getValuesOut().getType())
2233  .failed() ||
2234  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2235  /* outType = */ getValuesOut().getType())
2236  .failed()) {
2237  return failure();
2238  }
2239  return success();
2240 }
2241 
2242 static LogicalResult ReduceInferReturnTypes(
2243  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2244  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2245  int64_t axisVal = axis.getValue().getSExtValue();
2246  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2247  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2248  return success();
2249  }
2250 
2251  SmallVector<int64_t> outputShape;
2252  operandShape.getDims(outputShape);
2253  outputShape[axisVal] = 1;
2254  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2255  return success();
2256 }
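// For example (illustrative values only): reducing a tensor<2x3x4xf32> along
// axis = 1 yields an inferred output shape of [2, 1, 4]; the reduced
// dimension is kept with size 1 rather than dropped.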
2257 
2258 #define COMPATIBLE_RETURN_TYPES(OP) \
2259  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2260  if (l.size() != r.size() || l.size() != 1) \
2261  return false; \
2262  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2263  return false; \
2264  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2265  }
2266 
2267 #define REDUCE_SHAPE_INFER(OP) \
2268  LogicalResult OP::inferReturnTypeComponents( \
2269  MLIRContext *context, ::std::optional<Location> location, \
2270  OP::Adaptor adaptor, \
2271  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2272  Type inputType = \
2273  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2274  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2275  const Properties &prop = adaptor.getProperties(); \
2276  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2277  inferredReturnShapes); \
2278  } \
2279  COMPATIBLE_RETURN_TYPES(OP)
2280 
2281 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2282 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2283 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2284 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2285 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2286 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2287 #undef REDUCE_SHAPE_INFER
2288 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2289 #undef COMPATIBLE_RETURN_TYPES
2290 
2291 template <typename T>
2292 static LogicalResult verifyReduceOp(T op) {
2293  // All TOSA reduce Ops have input, output and axis.
2294  TensorType inputType = op.getInput().getType();
2295  TensorType outputType = op.getOutput().getType();
2296  int32_t reduceAxis = op.getAxis();
2297 
2298  if (reduceAxis < 0) {
2299  op.emitOpError("reduce axis must not be negative");
2300  return failure();
2301  }
2302  if (inputType.hasRank()) {
2303  int64_t inputRank = inputType.getRank();
2304  // We allow for a special case where the input/output shape has rank 0 and
2305  // axis is also 0.
2306  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2307  op.emitOpError("expect input tensor rank (")
2308  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2309  << ")";
2310  return failure();
2311  }
2312  }
2313  if (outputType.hasRank()) {
2314  int64_t outputRank = outputType.getRank();
2315  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2316  op.emitOpError(
2317  "expect output tensor rank to be equal to input tensor rank");
2318  return failure();
2319  }
2320  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2321  op.emitOpError("expect output tensor rank (")
2322  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2323  << ")";
2324  return failure();
2325  }
2326  // We can only verify the reduced dimension size to be 1 if this is not
2327  // the special case of output rank == 0.
2328  if (outputRank != 0) {
2329  auto outputShape = outputType.getShape();
2330  if (!outputType.isDynamicDim(reduceAxis) &&
2331  outputShape[reduceAxis] != 1) {
2332  op.emitOpError("expect reduced dimension size to be 1, got ")
2333  << outputShape[reduceAxis];
2334  return failure();
2335  }
2336  }
2337  }
2338  return success();
2339 }
2340 
2341 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2342 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2343 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2344 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2345 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2346 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2347 
2348 static LogicalResult NAryInferReturnTypes(
2349  const ValueShapeRange &operands,
2350  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2351  llvm::SmallVector<int64_t> outShape;
2352  if (resolveBroadcastShape(operands, outShape).failed()) {
2353  inferredReturnShapes.push_back(ShapedTypeComponents());
2354  } else {
2355  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2356  }
2357  return success();
2358 }
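// For example (illustrative values only): broadcasting operands of types
// tensor<2x1x5xf32> and tensor<1x3x5xf32> resolves to an inferred output
// shape of [2, 3, 5]; if the shapes are not broadcast-compatible, the
// components are simply left unranked here rather than reporting an error.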
2359 
2360 #define NARY_SHAPE_INFER(OP) \
2361  LogicalResult OP::inferReturnTypeComponents( \
2362  MLIRContext *context, ::std::optional<Location> location, \
2363  ValueShapeRange operands, DictionaryAttr attributes, \
2364  OpaqueProperties properties, RegionRange regions, \
2365  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2366  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2367  }
2368 
2369 NARY_SHAPE_INFER(tosa::AbsOp)
2370 NARY_SHAPE_INFER(tosa::AddOp)
2371 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2372 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2373 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2374 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2375 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2376 NARY_SHAPE_INFER(tosa::CastOp)
2377 NARY_SHAPE_INFER(tosa::CeilOp)
2378 NARY_SHAPE_INFER(tosa::ClampOp)
2379 NARY_SHAPE_INFER(tosa::ClzOp)
2380 NARY_SHAPE_INFER(tosa::CosOp)
2381 NARY_SHAPE_INFER(tosa::ExpOp)
2382 NARY_SHAPE_INFER(tosa::FloorOp)
2383 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2384 NARY_SHAPE_INFER(tosa::GreaterOp)
2385 NARY_SHAPE_INFER(tosa::IdentityOp)
2386 NARY_SHAPE_INFER(tosa::IntDivOp)
2387 NARY_SHAPE_INFER(tosa::LogOp)
2388 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2389 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2390 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2391 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2392 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2393 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2394 NARY_SHAPE_INFER(tosa::MaximumOp)
2395 NARY_SHAPE_INFER(tosa::MinimumOp)
2396 NARY_SHAPE_INFER(tosa::PowOp)
2397 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2398 NARY_SHAPE_INFER(tosa::ReverseOp)
2399 NARY_SHAPE_INFER(tosa::RsqrtOp)
2400 NARY_SHAPE_INFER(tosa::SinOp)
2401 NARY_SHAPE_INFER(tosa::SelectOp)
2402 NARY_SHAPE_INFER(tosa::SubOp)
2403 NARY_SHAPE_INFER(tosa::TanhOp)
2404 NARY_SHAPE_INFER(tosa::ErfOp)
2405 NARY_SHAPE_INFER(tosa::SigmoidOp)
2406 #undef NARY_SHAPE_INFER
2407 
2408 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2409  MLIRContext *context, ::std::optional<Location> location,
2410  NegateOp::Adaptor adaptor,
2411  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2412  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2413  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2414  return success();
2415 }
2416 
2417 LogicalResult tosa::NegateOp::verify() {
2418  // Verify same element type
2419  const Type input1Type = getInput1().getType();
2420  const Type outputType = getOutput().getType();
2421  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2422  return failure();
2423 
2424  // Verify same shape
2425  const SmallVector<Type, 2> types = {input1Type, outputType};
2426  if (failed(verifyCompatibleShapes(types)))
2427  return emitOpError() << "requires the same shape for input1 and output";
2428 
2429  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2430  const Type input1ZpEType =
2431  getStorageElementTypeOrSelf(getInput1Zp().getType());
2432  if (input1EType != input1ZpEType) {
2433  return emitOpError("expect both input1 and its zero point are the same "
2434  "element type, got ")
2435  << input1EType << " and " << input1ZpEType;
2436  }
2437  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2438  const Type outputZpEType =
2439  getStorageElementTypeOrSelf(getOutputZp().getType());
2440  if (outputEType != outputZpEType) {
2441  return emitOpError("expect both output and its zero point are the same "
2442  "element type, got ")
2443  << outputEType << " and " << outputZpEType;
2444  }
2445 
2446  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2447  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2448  return failure();
2449 
2450  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2451  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2452  return failure();
2453 
2454  return success();
2455 }
2456 
2457 static LogicalResult poolingInferReturnTypes(
2458  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2459  ArrayRef<int64_t> pad,
2460  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2461  llvm::SmallVector<int64_t> outputShape;
2462  outputShape.resize(4, ShapedType::kDynamic);
2463 
2464  // If the input type is unranked, we only know the output rank, not the dims.
2465  if (!inputShape) {
2466  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2467  return success();
2468  }
2469 
2470  // Batch size and number of channels are unchanged by a pooling layer.
2471  outputShape[0] = inputShape.getDimSize(0);
2472  outputShape[3] = inputShape.getDimSize(3);
2473 
2474  int64_t height = inputShape.getDimSize(1);
2475  int64_t width = inputShape.getDimSize(2);
2476 
2477  if (!ShapedType::isDynamic(height)) {
2478  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2479  outputShape[1] = padded / stride[0] + 1;
2480  }
2481 
2482  if (!ShapedType::isDynamic(width)) {
2483  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2484  outputShape[2] = padded / stride[1] + 1;
2485  }
2486 
2487  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2488  return success();
2489 }
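// For example (illustrative values only): a pooling input with height 32,
// pad_top = pad_bottom = 1, kernel height 3 and stride 2 gives
// (32 + 1 + 1 - 3) / 2 + 1 = 16 as the inferred output height, with the
// width computed the same way from pad[2], pad[3], kernel[1] and stride[1].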
2490 
2491 LogicalResult Conv2DOp::inferReturnTypeComponents(
2492  MLIRContext *context, ::std::optional<Location> location,
2493  Conv2DOp::Adaptor adaptor,
2494  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2495  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2496 
2497  int64_t inputWidth = ShapedType::kDynamic;
2498  int64_t inputHeight = ShapedType::kDynamic;
2499  int64_t weightWidth = ShapedType::kDynamic;
2500  int64_t weightHeight = ShapedType::kDynamic;
2501 
2502  // Input shape describes input width/height and batch.
2503 
2504  ShapeAdaptor inputShape(adaptor.getInput().getType());
2505  if (inputShape.hasRank()) {
2506  outputShape[0] = inputShape.getDimSize(0);
2507  inputHeight = inputShape.getDimSize(1);
2508  inputWidth = inputShape.getDimSize(2);
2509  }
2510 
2511  // Weight shape describes the filter width/height and the output channels.
2512  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2513  if (weightShape.hasRank()) {
2514  outputShape[3] = weightShape.getDimSize(0);
2515  weightHeight = weightShape.getDimSize(1);
2516  weightWidth = weightShape.getDimSize(2);
2517  }
2518 
2519  // Bias shape can describe the output channels.
2520  ShapeAdaptor biasShape(adaptor.getBias().getType());
2521  if (biasShape.hasRank()) {
2522  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2523  ? biasShape.getDimSize(0)
2524  : outputShape[3];
2525  }
2526 
2527  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2528  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2529  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2530 
2531  if (!ShapedType::isDynamic(inputHeight) &&
2532  !ShapedType::isDynamic(weightHeight)) {
2533  int64_t inputSize = inputHeight + padding[0] + padding[1];
2534  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2535  int64_t unstridedResult = inputSize - filterSize + 1;
2536  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2537  }
2538 
2539  if (!ShapedType::isDynamic(inputWidth) &&
2540  !ShapedType::isDynamic(weightWidth)) {
2541  int64_t inputSize = inputWidth + padding[2] + padding[3];
2542  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2543  int64_t unstridedResult = inputSize - filterSize + 1;
2544  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2545  }
2546 
2547  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2548  return success();
2549 }
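// For example (illustrative values only): with input height 15,
// pad_top = pad_bottom = 1, kernel height 3, dilation 2 and stride 2, the
// effective filter size is (3 - 1) * 2 + 1 = 5, the unstrided result is
// 15 + 1 + 1 - 5 + 1 = 13, and the inferred output height is
// (13 - 1) / 2 + 1 = 7.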
2550 
2551 LogicalResult Conv2DOp::verify() {
2552  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2553  return failure();
2554 
2555  llvm::ArrayRef<int64_t> padding = getPad();
2556  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
2557  return emitOpError("expect all padding values to be >= 0, got ") << padding;
2558 
2559  llvm::ArrayRef<int64_t> strides = getStride();
2560  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
2561  return emitOpError("expect all stride values to be >= 1, got ") << strides;
2562 
2563  llvm::ArrayRef<int64_t> dilations = getDilation();
2564  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
2565  return emitOpError("expect all dilation values to be >= 1, got ")
2566  << dilations;
2567 
2568  const RankedTensorType outputType =
2569  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2570  if (!outputType)
2571  // Skip following checks if output is not ranked
2572  return success();
2573 
2574  const RankedTensorType inputType =
2575  llvm::dyn_cast<RankedTensorType>(getInput().getType());
2576  const RankedTensorType weightType =
2577  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
2578 
2579  if (inputType && weightType) {
2580  const auto verifyOutputSize =
2581  [this](const int64_t inputSize, const int64_t kernelSize,
2582  const int64_t outputSize, const int64_t padBefore,
2583  const int64_t padAfter, const int64_t stride,
2584  const int64_t dilation, const llvm::StringRef dimName,
2585  const llvm::StringRef dimAxis,
2586  const llvm::StringRef padBeforeName,
2587  const llvm::StringRef padAfterName) -> LogicalResult {
2588  if (inputSize == ShapedType::kDynamic ||
2589  kernelSize == ShapedType::kDynamic)
2590  return success();
2591 
2592  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
2593  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
2594  stride);
2595  if (!calculatedOutSizeMinusOne.has_value())
2596  return emitOpError("expected input_")
2597  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
2598  << padAfterName << " - (kernel_" << dimName
2599  << " - 1) * dilation_" << dimAxis
2600  << " to be wholly divisible by stride_" << dimAxis << ", got ("
2601  << inputSize << " - 1 + " << padBefore << " + " << padAfter
2602  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
2603  << stride;
2604 
2605  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
2606  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
2607  return emitOpError("calculated output ")
2608  << dimName << " did not match expected: "
2609  << "calculated=" << calculatedOutSize
2610  << ", expected=" << outputSize;
2611 
2612  return success();
2613  };
2614 
2615  if (failed(verifyOutputSize(
2616  inputType.getDimSize(1), weightType.getDimSize(1),
2617  outputType.getDimSize(1), padding[0], padding[1], strides[0],
2618  dilations[0], "height", "y", "top", "bottom")))
2619  return failure();
2620 
2621  if (failed(verifyOutputSize(
2622  inputType.getDimSize(2), weightType.getDimSize(2),
2623  outputType.getDimSize(2), padding[2], padding[3], strides[1],
2624  dilations[1], "width", "x", "left", "right")))
2625  return failure();
2626  }
2627 
2628  const RankedTensorType biasType =
2629  llvm::dyn_cast<RankedTensorType>(getBias().getType());
2630  if (!biasType)
2631  // Skip following checks if bias is not ranked
2632  return success();
2633 
2634  const int64_t biasChannels = biasType.getDimSize(0);
2635  const int64_t outputChannels = outputType.getDimSize(3);
2636  if (biasChannels == ShapedType::kDynamic ||
2637  outputChannels == ShapedType::kDynamic)
2638  // Skip following checks if biasChannels or outputChannels is dynamic dim
2639  return success();
2640 
2641  if (biasChannels != outputChannels && biasChannels != 1)
2642  return emitOpError(
2643  "bias channels expected to be equal to output channels (")
2644  << outputChannels << ") or 1, got " << biasChannels;
2645  return success();
2646 }
2647 
2648 LogicalResult Conv3DOp::inferReturnTypeComponents(
2649  MLIRContext *context, ::std::optional<Location> location,
2650  Conv3DOp::Adaptor adaptor,
2651  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2652  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
2653 
2654  int64_t inputWidth = ShapedType::kDynamic;
2655  int64_t inputHeight = ShapedType::kDynamic;
2656  int64_t inputDepth = ShapedType::kDynamic;
2657 
2658  int64_t weightWidth = ShapedType::kDynamic;
2659  int64_t weightHeight = ShapedType::kDynamic;
2660  int64_t weightDepth = ShapedType::kDynamic;
2661 
2662  // Input shape describes input width/height and batch.
2663  ShapeAdaptor inputShape(adaptor.getInput().getType());
2664  if (inputShape.hasRank()) {
2665  outputShape[0] = inputShape.getDimSize(0);
2666  inputDepth = inputShape.getDimSize(1);
2667  inputHeight = inputShape.getDimSize(2);
2668  inputWidth = inputShape.getDimSize(3);
2669  }
2670 
2671  // Weight shape describes the filter width/height and the output channels.
2672  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2673  if (weightShape.hasRank()) {
2674  outputShape[4] = weightShape.getDimSize(0);
2675  weightDepth = weightShape.getDimSize(1);
2676  weightHeight = weightShape.getDimSize(2);
2677  weightWidth = weightShape.getDimSize(3);
2678  }
2679 
2680  // Bias shape can describe the output channels.
2681  ShapeAdaptor biasShape(adaptor.getBias().getType());
2682  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
2683  outputShape[4] = biasShape.getDimSize(0);
2684  }
2685 
2686  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2687  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2688  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
2689 
2690  if (!ShapedType::isDynamic(inputDepth) &&
2691  !ShapedType::isDynamic(weightDepth)) {
2692  int32_t inputSize = inputDepth + pad[0] + pad[1];
2693  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
2694  int32_t unstridedResult = inputSize - filterSize + 1;
2695  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2696  }
2697 
2698  if (!ShapedType::isDynamic(inputHeight) &&
2699  !ShapedType::isDynamic(weightHeight)) {
2700  int32_t inputSize = inputHeight + pad[2] + pad[3];
2701  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
2702  int32_t unstridedResult = inputSize - filterSize + 1;
2703  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2704  }
2705 
2706  if (!ShapedType::isDynamic(inputWidth) &&
2707  !ShapedType::isDynamic(weightWidth)) {
2708  int32_t inputSize = inputWidth + pad[4] + pad[5];
2709  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
2710  int32_t unstridedResult = inputSize - filterSize + 1;
2711  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
2712  }
2713 
2714  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2715  return success();
2716 }
2717 
2718 LogicalResult Conv3DOp::verify() {
2719  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2720  return failure();
2721  return success();
2722 }
2723 
2724 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
2725  MLIRContext *context, ::std::optional<Location> location,
2726  AvgPool2dOp::Adaptor adaptor,
2727  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2728  ShapeAdaptor inputShape(adaptor.getInput().getType());
2729  const Properties &prop = adaptor.getProperties();
2730  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2731  inferredReturnShapes);
2732 }
2733 
2734 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2735  MLIRContext *context, ::std::optional<Location> location,
2736  MaxPool2dOp::Adaptor adaptor,
2737  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2738  ShapeAdaptor inputShape(adaptor.getInput().getType());
2739  const Properties &prop = adaptor.getProperties();
2740  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2741  inferredReturnShapes);
2742 }
2743 
2744 LogicalResult MaxPool2dOp::verify() {
2745  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2746  /* outType = */ getOutput().getType())))
2747  return failure();
2748 
2749  if (failed(verifyPoolingOp(*this)))
2750  return failure();
2751 
2752  return success();
2753 }
2754 
2755 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2756  MLIRContext *context, ::std::optional<Location> location,
2757  DepthwiseConv2DOp::Adaptor adaptor,
2758  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2759  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2760 
2761  int64_t inputWidth = ShapedType::kDynamic;
2762  int64_t inputHeight = ShapedType::kDynamic;
2763  int64_t inputChannels = ShapedType::kDynamic;
2764 
2765  int64_t weightWidth = ShapedType::kDynamic;
2766  int64_t weightHeight = ShapedType::kDynamic;
2767  int64_t depthChannels = ShapedType::kDynamic;
2768 
2769  // Input shape describes input width/height and batch.
2770  ShapeAdaptor inputShape(adaptor.getInput().getType());
2771  if (inputShape.hasRank()) {
2772  outputShape[0] = inputShape.getDimSize(0);
2773  inputHeight = inputShape.getDimSize(1);
2774  inputWidth = inputShape.getDimSize(2);
2775  inputChannels = inputShape.getDimSize(3);
2776  }
2777 
2778  // Weight shape describes the filter width/height and the output channels.
2779  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2780  if (weightShape.hasRank()) {
2781  weightHeight = weightShape.getDimSize(0);
2782  weightWidth = weightShape.getDimSize(1);
2783  inputChannels = ShapedType::isDynamic(inputChannels)
2784  ? weightShape.getDimSize(2)
2785  : inputChannels;
2786  depthChannels = weightShape.getDimSize(3);
2787  }
2788 
2789  // If both inputChannels and depthChannels are available we can determine
2790  // the output channels.
2791  if (!ShapedType::isDynamic(inputChannels) &&
2792  !ShapedType::isDynamic(depthChannels)) {
2793  outputShape[3] = inputChannels * depthChannels;
2794  }
2795 
2796  // Bias shape can describe the output channels.
2797  ShapeAdaptor biasShape(adaptor.getBias().getType());
2798  if (biasShape.hasRank()) {
2799  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2800  ? biasShape.getDimSize(0)
2801  : outputShape[3];
2802  }
2803 
2804  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2805  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2806  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2807 
2808  if (!ShapedType::isDynamic(inputHeight) &&
2809  !ShapedType::isDynamic(weightHeight)) {
2810  int64_t inputSize = inputHeight + padding[0] + padding[1];
2811  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2812  int64_t unstridedResult = inputSize - filterSize + 1;
2813  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2814  }
2815 
2816  if (!ShapedType::isDynamic(inputWidth) &&
2817  !ShapedType::isDynamic(weightWidth)) {
2818  int64_t inputSize = inputWidth + padding[2] + padding[3];
2819  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2820  int64_t unstridedResult = inputSize - filterSize + 1;
2821  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2822  }
2823 
2824  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2825  return success();
2826 }
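// For example (illustrative values only): a depthwise convolution with 8
// input channels and a weight tensor whose last dimension (the channel
// multiplier) is 3 produces 8 * 3 = 24 output channels; the spatial
// dimensions follow the same formula as Conv2D above.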
2827 
2828 LogicalResult DepthwiseConv2DOp::verify() {
2829  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2830  return failure();
2831  return success();
2832 }
2833 
2834 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
2835  MLIRContext *context, ::std::optional<Location> location,
2836  TransposeConv2DOp::Adaptor adaptor,
2837  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2838  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2839 
2840  int64_t inputWidth = ShapedType::kDynamic;
2841  int64_t inputHeight = ShapedType::kDynamic;
2842  int64_t weightWidth = ShapedType::kDynamic;
2843  int64_t weightHeight = ShapedType::kDynamic;
2844 
2845  // Input shape describes input width/height and batch.
2846  ShapeAdaptor inputShape(adaptor.getInput().getType());
2847  if (inputShape.hasRank()) {
2848  outputShape[0] = ShapedType::isDynamic(outputShape[0])
2849  ? inputShape.getDimSize(0)
2850  : outputShape[0];
2851  inputHeight = inputShape.getDimSize(1);
2852  inputWidth = inputShape.getDimSize(2);
2853  }
2854 
2855  // Weight shape describes the filter width/height and the output channels.
2856  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2857  if (weightShape.hasRank()) {
2858  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2859  ? weightShape.getDimSize(0)
2860  : outputShape[3];
2861  weightHeight = weightShape.getDimSize(1);
2862  weightWidth = weightShape.getDimSize(2);
2863  }
2864 
2865  // Bias shape can describe the output channels.
2866  ShapeAdaptor biasShape(adaptor.getBias().getType());
2867  if (biasShape.hasRank()) {
2868  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2869  ? biasShape.getDimSize(0)
2870  : outputShape[3];
2871  }
2872 
2873  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
2874  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2875 
2876  if (!ShapedType::isDynamic(inputHeight) &&
2877  !ShapedType::isDynamic(weightHeight)) {
2878  int64_t calculateSize =
2879  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2880  outputShape[1] =
2881  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
2882  }
2883 
2884  if (!ShapedType::isDynamic(inputWidth) &&
2885  !ShapedType::isDynamic(weightWidth)) {
2886  int64_t calculateSize =
2887  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2888  outputShape[2] =
2889  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
2890  }
2891 
2892  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2893  return success();
2894 }
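// For example (illustrative values only): with input height 4, stride 2,
// out_pad = [0, 0] and kernel height 3, the inferred output height is
// (4 - 1) * 2 + 0 + 0 + 3 = 9.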
2895 
2896 LogicalResult TransposeConv2DOp::verify() {
2897  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2898  return failure();
2899  return success();
2900 }
2901 
2902 LogicalResult RescaleOp::verify() {
2903  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
2904  if (!inputType) {
2905  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
2906  return failure();
2907  }
2908 
2909  auto inputElementType =
2910  getStorageElementTypeOrSelf(inputType.getElementType());
2911  if (!mlir::isa<IntegerType>(inputElementType)) {
2912  emitOpError("expect input to have integer element type, got ")
2913  << inputElementType;
2914  return failure();
2915  }
2916 
2917  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
2918  if (!outputType) {
2919  emitOpError("expect shaped tensor for output, got ")
2920  << getOutput().getType();
2921  return failure();
2922  }
2923 
2924  auto outputElementType =
2925  getStorageElementTypeOrSelf(outputType.getElementType());
2926  if (!mlir::isa<IntegerType>(outputElementType)) {
2927  emitOpError("expect output to have integer element type, got ")
2928  << outputElementType;
2929  return failure();
2930  }
2931 
2932  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
2933  .failed())
2934  return failure();
2935 
2936  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
2937  .failed())
2938  return failure();
2939 
2940  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
2941  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
2942  return failure();
2943 
2944  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2945  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2946  return failure();
2947 
2948  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
2949  if (!multiplierType) {
2950  emitOpError("expect shaped tensor for multiplier, got ")
2951  << getMultiplier().getType();
2952  return failure();
2953  }
2954 
2955  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
2956  if (!shiftType) {
2957  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
2958  return failure();
2959  }
2960 
2961  // multiplier element type must be i32 for scale32 = true
2962  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
2963  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
2964  << multiplierType.getElementType();
2965  return failure();
2966  }
2967 
2968  // multiplier element type must be i16 for scale32 = false
2969  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
2970  emitOpError(
2971  "expect i16 element type for multiplier for scale32=false, got ")
2972  << multiplierType.getElementType();
2973  return failure();
2974  }
2975 
2976  if (!inputType.hasRank())
2977  return success();
2978 
2979  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
2980  // if per_channel = false, otherwise numChannels is the size of the input
2981  // shape's last dimension.
2982  int64_t numChannels = 1;
2983  if (getPerChannel()) {
2984  numChannels = inputType.getDimSize(inputType.getRank() - 1);
2985  }
2986 
2987  if (!multiplierType.hasRank())
2988  return success();
2989 
2990  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
2991  // multiplier input has rank 1 by dialect definition
2992  if (multiplierShape[0] != ShapedType::kDynamic &&
2993  multiplierShape[0] != numChannels) {
2994  emitOpError("expect shape of { ")
2995  << numChannels << " } for multiplier input, got { "
2996  << multiplierShape[0] << " }";
2997  return failure();
2998  }
2999 
3000  if (!shiftType.hasRank())
3001  return success();
3002 
3003  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3004  // shift input has rank 1 by dialect definition
3005  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3006  emitOpError("expect shape of { ")
3007  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3008  return failure();
3009  }
3010 
3011  return success();
3012 }
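// Example (illustrative, shapes assumed): for an input of type tensor<2x3x4xi8>
// with per_channel = true, both multiplier and shift must have shape {4}, the
// size of the last input dimension; with per_channel = false they must have
// shape {1}. With scale32 = true the multiplier element type must be i32,
// otherwise it must be i16, as checked above.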
3013 
3014 LogicalResult RescaleOp::inferReturnTypeComponents(
3015  MLIRContext *context, ::std::optional<Location> location,
3016  RescaleOp::Adaptor adaptor,
3017  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3018  ShapeAdaptor inputShape(adaptor.getInput().getType());
3019  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3020  return success();
3021 }
3022 
3023 LogicalResult IfOp::inferReturnTypeComponents(
3024  MLIRContext *context, ::std::optional<Location> location,
3025  IfOp::Adaptor adaptor,
3026  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3027  llvm::SmallVector<tosa::YieldOp> yieldOps;
3028  for (Region *region : adaptor.getRegions()) {
3029  for (auto &block : *region)
3030  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3031  yieldOps.push_back(returnOp);
3032  }
3033 
3034  if (yieldOps.empty())
3035  return failure();
3036 
3037  // Get the initial type information for the yield op.
3038  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3039  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3040  for (auto operand : yieldOps.front().getOperands()) {
3041  resultKnowledge.push_back(
3042  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3043  }
3044 
3045  for (auto yieldOp : yieldOps) {
3046  if (resultKnowledge.size() != yieldOp.getNumOperands())
3047  return failure();
3048 
3049  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3050  int32_t index = it.index();
3051  auto meet = ValueKnowledge::meet(
3052  resultKnowledge[index],
3053  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3054  if (!meet)
3055  continue;
3056  resultKnowledge[index] = meet;
3057  }
3058  }
3059 
3060  for (const ValueKnowledge &result : resultKnowledge) {
3061  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3062  }
3063 
3064  return success();
3065 }
3066 
3067 LogicalResult WhileOp::inferReturnTypeComponents(
3068  MLIRContext *context, ::std::optional<Location> location,
3069  WhileOp::Adaptor adaptor,
3070  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3071  llvm::SmallVector<tosa::YieldOp> yieldOps;
3072  for (auto &block : adaptor.getBodyGraph())
3073  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3074  yieldOps.push_back(returnOp);
3075 
3076  // TOSA's while must have a tosa.yield as its terminator. If none is found,
3077  // this tosa.while is invalid.
3078  if (yieldOps.empty())
3079  return failure();
3080 
3081  // Get the initial type information from the operand types.
3082  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3083  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3084  for (auto operand : yieldOps.front().getOperands()) {
3085  resultKnowledge.push_back(
3086  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3087  }
3088 
3089  for (auto yieldOp : yieldOps) {
3090  if (resultKnowledge.size() != yieldOp.getNumOperands())
3091  return failure();
3092 
3093  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3094  int32_t index = it.index();
3095  if (auto meet = ValueKnowledge::meet(
3096  resultKnowledge[index],
3097  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3098  resultKnowledge[index] = meet;
3099  }
3100  }
3101  }
3102 
3103  for (const ValueKnowledge &result : resultKnowledge) {
3104  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3105  }
3106 
3107  return success();
3108 }
3109 
3110 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3111  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3112  return llvm::to_vector<4>(vt.getShape());
3113  return std::nullopt;
3114 }
3115 
3116 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
3117 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3118  // Create the regions for 'then'.
3119  result.regions.reserve(2);
3120  Region *thenRegion = result.addRegion();
3121  Region *elseRegion = result.addRegion();
3122 
3123  auto &builder = parser.getBuilder();
3124  OpAsmParser::UnresolvedOperand cond;
3125  // Create an i1 tensor type for the boolean condition.
3126  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3127  if (parser.parseOperand(cond) ||
3128  parser.resolveOperand(cond, i1Type, result.operands))
3129  return failure();
3130  // Parse optional results type list.
3131  if (parser.parseOptionalArrowTypeList(result.types))
3132  return failure();
3133  // Parse the 'then' region.
3134  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3135  return failure();
3136 
3137  // If we find an 'else' keyword then parse the 'else' region.
3138  if (!parser.parseOptionalKeyword("else")) {
3139  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3140  return failure();
3141  }
3142 
3143  // Parse the optional attribute list.
3144  if (parser.parseOptionalAttrDict(result.attributes))
3145  return failure();
3146  return success();
3147 }
3148 
3149 void IfOp::print(OpAsmPrinter &p) {
3150  bool printBlockTerminators = false;
3151 
3152  p << " " << getCondition();
3153  if (!getResults().empty()) {
3154  p << " -> (" << getResultTypes() << ")";
3155  // Print yield explicitly if the op defines values.
3156  printBlockTerminators = true;
3157  }
3158  p << ' ';
3159  p.printRegion(getThenGraph(),
3160  /*printEntryBlockArgs=*/false,
3161  /*printBlockTerminators=*/printBlockTerminators);
3162 
3163  // Print the 'else' region if it exists and has a block.
3164  auto &elseRegion = getElseGraph();
3165  if (!elseRegion.empty()) {
3166  p << " else ";
3167  p.printRegion(elseRegion,
3168  /*printEntryBlockArgs=*/false,
3169  /*printBlockTerminators=*/printBlockTerminators);
3170  }
3171 
3172  p.printOptionalAttrDict((*this)->getAttrs());
3173 }
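// Illustrative textual form round-tripped by the parse/print methods above
// (operand names and region bodies are placeholders):
//   %r = tosa.cond_if %cond -> (tensor<f32>) {
//     ...   // 'then' region, terminated by tosa.yield
//   } else {
//     ...   // optional 'else' region
//   }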
3174 
3175 LogicalResult ReverseOp::verify() {
3176  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3177  /* outType = */ getOutput().getType())
3178  .failed())
3179  return failure();
3180  TensorType inputType = getInput1().getType();
3181  TensorType outputType = getOutput().getType();
3182  int32_t reverseAxis = getAxis();
3183 
3184  if (reverseAxis < 0)
3185  return emitOpError("expected non-negative reverse axis");
3186  if (inputType.hasRank()) {
3187  int64_t inputRank = inputType.getRank();
3188  // We allow for a special case where the input/output shape has rank 0 and
3189  // axis is also 0.
3190  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3191  return emitOpError("expect input tensor rank (")
3192  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3193  << ")";
3194  }
3195  if (outputType.hasRank()) {
3196  int64_t outputRank = outputType.getRank();
3197  if (inputType.hasRank() && outputRank != inputType.getRank())
3198  return emitOpError(
3199  "expect output tensor rank to be equal to input tensor rank");
3200  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3201  return emitOpError("expect output tensor rank (")
3202  << outputRank << ") to be larger than reverse axis ("
3203  << reverseAxis << ")";
3204  }
3205  return success();
3206 }
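// Example (illustrative): reversing a tensor<4x8xf32> is valid for axis 0 or 1;
// axis 2 triggers the rank check above. A rank-0 input is only valid with axis 0.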
3207 
3208 LogicalResult tosa::SelectOp::verify() {
3209  // verify input2 and input3 have same element type as output
3210  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3211  /* outType = */ getOutput().getType())
3212  .failed() ||
3213  verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3214  /* outType = */ getOutput().getType())
3215  .failed()) {
3216  return failure();
3217  }
3218  // verify input1 has element type of bool
3219  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3220  if (!predicateType) {
3221  return emitOpError("expect shaped tensor for input1, got ")
3222  << getInput1().getType();
3223  }
3224  auto predicateElementType = predicateType.getElementType();
3225  if (!predicateElementType.isInteger(1)) {
3226  return emitOpError("expect element type of bool for input1, got ")
3227  << predicateElementType;
3228  }
3229 
3230  return success();
3231 }
3232 
3233 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
3234 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3235  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3236  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3237  Region *cond = result.addRegion();
3238  Region *body = result.addRegion();
3239 
3240  OptionalParseResult listResult =
3241  parser.parseOptionalAssignmentList(regionArgs, operands);
3242  if (listResult.has_value() && failed(listResult.value()))
3243  return failure();
3244 
3245  FunctionType functionType;
3246  SMLoc typeLoc = parser.getCurrentLocation();
3247  if (failed(parser.parseColonType(functionType)))
3248  return failure();
3249 
3250  result.addTypes(functionType.getResults());
3251 
3252  if (functionType.getNumInputs() != operands.size()) {
3253  return parser.emitError(typeLoc)
3254  << "expected as many input types as operands "
3255  << "(expected " << operands.size() << " got "
3256  << functionType.getNumInputs() << ")";
3257  }
3258 
3259  // Resolve input operands.
3260  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3261  parser.getCurrentLocation(),
3262  result.operands)))
3263  return failure();
3264 
3265  // Propagate the types into the region arguments.
3266  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3267  regionArgs[i].type = functionType.getInput(i);
3268 
3269  return failure(parser.parseRegion(*cond, regionArgs) ||
3270  parser.parseKeyword("do") || parser.parseRegion(*body) ||
3271  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3272 }
3273 
3274 static void printInitializationList(OpAsmPrinter &parser,
3275  Block::BlockArgListType blocksArgs,
3276  ValueRange initializers,
3277  StringRef prefix = "") {
3278  assert(blocksArgs.size() == initializers.size() &&
3279  "expected same length of arguments and initializers");
3280  if (initializers.empty())
3281  return;
3282 
3283  parser << prefix << '(';
3284  llvm::interleaveComma(
3285  llvm::zip(blocksArgs, initializers), parser,
3286  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3287  parser << ")";
3288 }
3289 
3290 void WhileOp::print(OpAsmPrinter &parser) {
3291  printInitializationList(parser, getCondGraph().front().getArguments(),
3292  getInputList(), " ");
3293  parser << " : ";
3294  parser.printFunctionalType(getInputList().getTypes(),
3295  getResults().getTypes());
3296  parser << ' ';
3297  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3298  parser << " do ";
3299  parser.printRegion(getBodyGraph());
3300  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3301 }
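// Illustrative textual form matching the parse/print methods above (names and
// types are placeholders):
//   %res = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...   // condition region, terminated by tosa.yield of an i1 tensor
//   } do {
//     ...   // body region, terminated by tosa.yield of the iteration values
//   }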
3302 
3303 // Create a rank-1 const tensor for zero point of the source tensor.
3304 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3305  Location loc,
3306  Type srcElemType,
3307  int64_t zp) {
3308  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3309  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3310  if (llvm::isa<FloatType>(srcElemType)) {
3311  auto zpAttr = DenseElementsAttr::get(
3312  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3313  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3314  }
3315  if (llvm::isa<IntegerType>(srcElemType)) {
3316  auto zpAttr =
3317  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3318  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3319  }
3320  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3321  return std::nullopt;
3322 }
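// Usage sketch (illustrative; builder and loc assumed to be in scope):
// materialize an i8 zero point of 42 as a rank-1 constant. Unsupported element
// types yield std::nullopt.
//   std::optional<Value> zp =
//       createZeroPointTensor(builder, loc, builder.getIntegerType(8), 42);
//   // Produces a tosa.const of type tensor<1xi8>.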
3323 
3324 //===----------------------------------------------------------------------===//
3325 // TOSA Shape and Shape Operators Helper functions.
3326 //===----------------------------------------------------------------------===//
3327 
3328 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3329  return mlir::isa<tosa::shapeType>(t);
3330 }
3331 
3332 LogicalResult
3333 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3334  int rank) {
3335  if (rank < 0)
3336  return emitError() << "invalid rank (must be >= 0): " << rank;
3337  return success();
3338 }
3339 
3340 LogicalResult mlir::OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3341  for (auto v : op->getOperands()) {
3342  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3343  Operation *definingOp = v.getDefiningOp();
3344  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3345  return op->emitOpError("shape operand is not compile time resolvable");
3346  }
3347  }
3348  }
3349  return success();
3350 }
3351 
3352 LogicalResult mlir::OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3353  for (auto type : op->getOperandTypes()) {
3354  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3355  return op->emitOpError("must have operands with tosa shape type");
3356  }
3357  }
3358  for (auto type : op->getResultTypes()) {
3359  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3360  return op->emitOpError("must have result with tosa shape type");
3361  }
3362  }
3363  return success();
3364 }
3365 
3366 LogicalResult
3367 mlir::OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
3368  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
3369  failed(verifyTosaShapeOperator(op)))
3370  return failure();
3371 
3372  // Helper that returns the rank of a tosa shape type.
3373  auto getRank = [](const Type type) {
3374  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
3375  };
3376  auto operandTypes = op->getOperandTypes();
3377  auto resultTypes = op->getResultTypes();
3378 
3379  auto rank = getRank(*op->getOperandTypes().begin());
3380  for (auto type : operandTypes) {
3381  if (getRank(type) != rank) {
3382  return op->emitOpError("operands don't have matching ranks");
3383  }
3384  }
3385  for (auto type : resultTypes) {
3386  if (getRank(type) != rank) {
3387  return op->emitOpError("result shape has different rank than operands");
3388  }
3389  }
3390  return success();
3391 }
3392 
3393 //===----------------------------------------------------------------------===//
3394 // TOSA Shape Operators verify functions.
3395 //===----------------------------------------------------------------------===//
3396 
3397 LogicalResult tosa::ConstShapeOp::verify() {
3398  // Check that the values attribute has rank 1.
3399  auto valuesRank = getValues().getType().getRank();
3400  if (valuesRank != 1)
3401  return emitOpError("expect elements in attribute values with rank 1");
3402  // Check that the number of elements in the values attribute equals the rank
3403  auto count = getValues().getNumElements();
3404  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
3405  if (!(count == rank || (count == 1 && rank == 0))) {
3406  return emitOpError("expect number of elements in attribute values (")
3407  << count << ") to be equal to the rank (" << rank
3408  << ") for the result shape type";
3409  }
3410  return success();
3411 }
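// Example (illustrative syntax): a rank-3 result shape whose values attribute
// holds exactly three elements, as required by the verifier above:
//   %s = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} : !tosa.shape<3>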
3412 
3413 //===----------------------------------------------------------------------===//
3414 // TOSA Attribute Definitions.
3415 //===----------------------------------------------------------------------===//
3416 
3417 #define GET_ATTRDEF_CLASSES
3418 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
3419 
3420 //===----------------------------------------------------------------------===//
3421 // TOSA Type Definitions.
3422 //===----------------------------------------------------------------------===//
3423 #define GET_TYPEDEF_CLASSES
3424 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
3425 
3426 //===----------------------------------------------------------------------===//
3427 // TOSA Operator Definitions.
3428 //===----------------------------------------------------------------------===//
3429 
3430 #define GET_OP_CLASSES
3431 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"