1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49 
50 namespace {
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52 
53 //===----------------------------------------------------------------------===//
54 // Dialect Function Inliner Interface.
55 //===----------------------------------------------------------------------===//
56 struct TosaInlinerInterface : public DialectInlinerInterface {
57  using DialectInlinerInterface::DialectInlinerInterface;
58 
59  //===--------------------------------------------------------------------===//
60  // Analysis Hooks.
61  //===--------------------------------------------------------------------===//
62 
63  /// All operations can be inlined by default.
64  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return true;
67  }
68 
69  /// All regions with If and While parent operators can be inlined.
70  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71  IRMapping &map) const final {
72  return (isa<tosa::IfOp>(dest->getParentOp()) ||
73  isa<tosa::WhileOp>(dest->getParentOp()));
74  }
75 };
76 
77 /// This class implements the bytecode interface for the Tosa dialect.
78 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79  TosaDialectBytecodeInterface(Dialect *dialect)
80  : BytecodeDialectInterface(dialect) {}
81 
82  //===--------------------------------------------------------------------===//
83  // Attributes
84 
85  Attribute readAttribute(DialectBytecodeReader &reader) const override {
86  return ::readAttribute(getContext(), reader);
87  }
88 
89  LogicalResult writeAttribute(Attribute attr,
90  DialectBytecodeWriter &writer) const override {
91  return ::writeAttribute(attr, writer);
92  }
93 
94  //===--------------------------------------------------------------------===//
95  // Types
96 
97  Type readType(DialectBytecodeReader &reader) const override {
98  return ::readType(getContext(), reader);
99  }
100 
101  LogicalResult writeType(Type type,
102  DialectBytecodeWriter &writer) const override {
103  return ::writeType(type, writer);
104  }
105 
106  void writeVersion(DialectBytecodeWriter &writer) const final {
107  // TODO: Populate.
108  }
109 
110  std::unique_ptr<DialectVersion>
111  readVersion(DialectBytecodeReader &reader) const final {
112  // TODO: Populate
113  reader.emitError("Dialect does not support versioning");
114  return nullptr;
115  }
116 
117  LogicalResult upgradeFromVersion(Operation *topLevelOp,
118  const DialectVersion &version) const final {
119  return success();
120  }
121 };
122 
123 } // namespace
124 
125 //===----------------------------------------------------------------------===//
126 // TOSA control flow support.
127 //===----------------------------------------------------------------------===//
128 
129 /// Returns the while loop body.
130 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131  return {&getBodyGraph()};
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Tosa dialect initialization.
136 //===----------------------------------------------------------------------===//
137 
138 void TosaDialect::initialize() {
139  addTypes<
140 #define GET_TYPEDEF_LIST
141 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
142  >();
143  addOperations<
144 #define GET_OP_LIST
145 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
146  >();
147  addAttributes<
148 #define GET_ATTRDEF_LIST
149 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
150  >();
151  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
152  declarePromisedInterfaces<
153  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
154  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
155  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
156  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
157  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
158  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
159  GreaterEqualOp, MatMulOp>();
160 }
161 
162 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
163  Type type, Location loc) {
164  // Tosa dialect constants only support ElementsAttr, unlike the standard
165  // dialect constant op, which supports all attributes.
166  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
167  return builder.create<tosa::ConstShapeOp>(
168  loc, type, llvm::cast<DenseIntElementsAttr>(value));
169  }
170  if (llvm::isa<ElementsAttr>(value))
171  return builder.create<tosa::ConstOp>(loc, type,
172  llvm::cast<ElementsAttr>(value));
173  return nullptr;
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Parsers and printers
178 //===----------------------------------------------------------------------===//
179 
180 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
181  Attribute &attr) {
182  if (succeeded(parser.parseOptionalEqual())) {
183  if (failed(parser.parseAttribute(attr))) {
184  return parser.emitError(parser.getCurrentLocation())
185  << "expected attribute";
186  }
187  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188  typeAttr = TypeAttr::get(typedAttr.getType());
189  }
190  return success();
191  }
192 
193  Type type;
194  if (failed(parser.parseColonType(type))) {
195  return parser.emitError(parser.getCurrentLocation()) << "expected type";
196  }
197  typeAttr = TypeAttr::get(type);
198 
199  return success();
200 }
201 
202 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
203  Attribute attr) {
204  bool needsSpace = false;
205  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206  if (!typedAttr || typedAttr.getType() != type.getValue()) {
207  p << ": ";
208  p.printAttribute(type);
209  needsSpace = true; // subsequent attr value needs a space separator
210  }
211  if (attr) {
212  if (needsSpace)
213  p << ' ';
214  p << "= ";
215  p.printAttribute(attr);
216  }
217 }
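
// Illustrative note (not from the original file): parseTypeOrAttr and
// printTypeOrAttr above implement the two textual forms of a typed value,
// presumably referenced from ODS as a custom<TypeOrAttr> directive. Roughly:
//   ... = dense<0> : tensor<1xi8>   // '=' form: an attribute is parsed and,
//                                   // when typed, fills `typeAttr` from it.
//   ... : tensor<1xi8>              // ':' form: only the type is parsed and
//                                   // `attr` stays unset.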
218 
219 // Create a pad_const tensor with value `val` of the required data type.
220 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
221  Value src, int32_t val) {
222  const auto srcType = getElementTypeOrSelf(src);
223  const auto srcElemType = getElementTypeOrSelf(src);
224  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
225  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
226  const auto padConstAttr{
227  llvm::isa<FloatType>(srcElemType)
228  ? DenseElementsAttr::get(padConstEType,
229  builder.getFloatAttr(srcElemType, val))
230  : DenseElementsAttr::get(padConstEType,
231  builder.getIntegerAttr(srcElemType, val))};
232  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
233 }
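
// Illustrative example (values assumed, not from the surrounding code): for a
// `src` of element type f32 and val == 0, createPadConstTensor materializes a
// single-element constant roughly equivalent to
//   "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
// which tosa.pad can then consume as its pad_const operand.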
234 
235 //===----------------------------------------------------------------------===//
236 // Tosa utilities.
237 //===----------------------------------------------------------------------===//
238 
239 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
240  if (lhs % rhs != 0)
241  return std::nullopt;
242  return lhs / rhs;
243 }
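
// Usage sketch for idivCheck: the result is present only when the division is
// exact, e.g.
//   idivCheck(10, 5) -> std::optional<int64_t>(2)
//   idivCheck(10, 4) -> std::nullopt (10 is not evenly divisible by 4)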
244 
245 //===----------------------------------------------------------------------===//
246 // Tosa utilities.
247 //===----------------------------------------------------------------------===//
248 
249 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
250  auto elementType = getElementTypeOrSelf(type);
251  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(elementType))
252  elementType = quantType.getStorageType();
253 
254  return elementType;
255 }
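
// Example (illustrative): for an element type such as
// !quant.uniform<i8:f32, 0.5> this returns the i8 storage type, while a plain
// f32 (or any non-quantized type) is returned unchanged.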
256 
257 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
258  Value valZp, StringRef name) {
259  Type eType = getStorageElementTypeOrSelf(val.getType());
260  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
261 
262  bool bothInts =
263  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
264  bool sameBitWidth =
265  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
266 
267  if (!bothInts || !sameBitWidth) {
268  return op->emitOpError()
269  << "expected " << name << " and " << name
270  << "_zp to both be integer of the same bitwidth, but got " << eType
271  << " vs. " << eZpType;
272  }
273  return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // TOSA Operator Verifiers.
278 //===----------------------------------------------------------------------===//
279 
280 template <typename T>
281 static LogicalResult verifyConvOp(T op) {
282  // All TOSA conv ops have input and weight arguments, which must be ranked
283  // tensors.
284  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
285  if (!inputType) {
286  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
287  return failure();
288  }
289 
290  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
291  if (!weightType) {
292  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
293  return failure();
294  }
295 
296  auto inputEType = inputType.getElementType();
297  auto weightEType = weightType.getElementType();
298  auto biasEType =
299  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
300  auto resultEType =
301  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
302  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
303  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
304 
305  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
306  inputEType = quantType.getStorageType();
307 
308  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
309  weightEType = quantType.getStorageType();
310 
311  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
312  biasEType = quantType.getStorageType();
313 
314  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
315  resultEType = quantType.getStorageType();
316 
317  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
318  // for now, only enforce bias element type == result element type for
319  // float types.
320  op.emitOpError(
321  "expect both bias and result to have same element type, got ")
322  << biasEType << " and " << resultEType;
323  return failure();
324  }
325 
326  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
327  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
328  if (inputEType != weightEType) {
329  op.emitOpError(
330  "expect both input and weight to have same element type, got ")
331  << inputEType << " and " << weightEType;
332  return failure();
333  }
334  }
335 
336  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
337  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
338 
339  // Either both must be float or both non-float.
340  if (inputIsFloat != weightIsFloat) {
341  op.emitOpError(
342  "expect both input and weight to be float or not together, got ")
343  << inputEType << " and " << weightEType;
344  return failure();
345  }
346 
347  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
348  if (inputEType != inputZpEType) {
349  return op.emitOpError("expect both input and its zero point are the same "
350  "element type, got ")
351  << inputEType << " and " << inputZpEType;
352  }
353 
354  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
355  if (weightEType != weightZpEType) {
356  return op.emitOpError("expect both weight and its zero point are the same "
357  "element type, got ")
358  << weightEType << " and " << weightZpEType;
359  }
360 
361  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
362  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
363  return failure();
364 
365  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
366  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
367  return failure();
368 
369  return success();
370 }
371 
372 LogicalResult tosa::ConstOp::verify() {
373 
374  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
375  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
376 
377  if (!attrType || !outputType) {
378  emitOpError("expected tensors for attr/result type");
379  return failure();
380  }
381 
382  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
383  outputType.getElementType())) {
384  if (result.getStorageType() == attrType.getElementType())
385  return success();
386  }
387 
388  if (attrType.getElementType() != outputType.getElementType()) {
389  emitOpError("expected same attr/result element types");
390  return failure();
391  }
392 
393  return success();
394 }
395 
396 template <typename T>
397 static LogicalResult verifyConvOpModes(T op) {
398  auto inputEType =
399  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
400 
401  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
402  inputEType = quantType.getStorageType();
403 
404  auto accType = op.getAccType();
405  if (inputEType.isInteger(8) && !accType.isInteger(32))
406  return op.emitOpError("accumulator type for i8 tensor is not i32");
407 
408  if (inputEType.isInteger(16) && !accType.isInteger(48))
409  return op.emitOpError("accumulator type for i16 tensor is not i48");
410 
411  if (isa<Float8E5M2Type, Float8E4M3FNType>(inputEType) && !accType.isF16())
412  return op.emitOpError("accumulator type for f8 tensor is not f16");
413 
414  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
415  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
416 
417  if (inputEType.isBF16() && !accType.isF32())
418  return op.emitOpError("accumulator type for bf16 tensor is not f32");
419 
420  if (inputEType.isF32() && !accType.isF32())
421  return op.emitOpError("accumulator type for f32 tensor is not f32");
422 
423  auto resultEType =
424  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
425 
426  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
427  resultEType = quantType.getStorageType();
428 
429  // Check allowed input/result element type combinations.
430  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
431  (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
432  (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
433  (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
434  (inputEType.isF16() && resultEType.isF16()) ||
435  (inputEType.isBF16() && resultEType.isBF16()) ||
436  (inputEType.isF32() && resultEType.isF32()))
437  return success();
438 
439  return op.emitOpError("input/output element types are incompatible.");
440 }
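
// For example, a conv with i8 input and weight must use acc_type = i32 and
// produce an i32 result, while an f16 convolution may accumulate in f16 or f32
// but must still produce an f16 result (per the combinations checked above).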
441 
442 // Verify that inType and outType have the same element types.
443 template <typename T>
444 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
445  auto inputType = llvm::dyn_cast<TensorType>(inType);
446  auto outputType = llvm::dyn_cast<TensorType>(outType);
447  if (!inputType) {
448  op.emitOpError("expect shaped tensor for input, got ") << inType;
449  return failure();
450  }
451  if (!outputType) {
452  op.emitOpError("expect shaped tensor for output, got ") << outType;
453  return failure();
454  }
455  auto inputElementType = inputType.getElementType();
456  auto outputElementType = outputType.getElementType();
457  auto inputQuantType =
458  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
459  auto outputQuantType =
460  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
461  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
462  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
463  inputElementType != outputElementType) {
464  // only check if both element types are int/index/float/UniformQuantized
465  // eg, not sure how to check quant::QuantizedType
466  // this happens in test_conv2d_q_grouped_convolution in
467  // tfl-to-tosa-pipeline.mlir
468  op.emitOpError("expect input and output to have same element type, got ")
469  << inputElementType << " and " << outputElementType;
470  return failure();
471  }
472  return success();
473 }
474 
475 LogicalResult tosa::ArgMaxOp::verify() {
476  const ShapedType resultType = llvm::cast<ShapedType>(getType());
477 
478  // Ensure the output is an integer type.
479  if (const auto resultETy = resultType.getElementType();
480  !resultETy.isIntOrIndex())
481  return emitOpError("result tensor is not of integer type");
482 
483  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
484  if (!inputType.hasRank())
485  return success();
486 
487  // Ensure axis is within the tensor rank
488  const int64_t axis = getAxisAttr().getInt();
489  if (((axis < 0) || axis >= inputType.getRank()))
490  return emitOpError("specified axis is outside the rank of the tensor");
491 
492  if (!resultType.hasRank())
493  return success();
494 
495  const ArrayRef<int64_t> inputShape = inputType.getShape();
496  const ArrayRef<int64_t> outputShape = resultType.getShape();
497  llvm::SmallVector<int64_t> expectedOutputShape(inputShape.begin(),
498  inputShape.end());
499  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
500  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
501  return emitOpError("expected output shape '")
502  << expectedOutputShape << "', got '" << outputShape << "'";
503 
504  return success();
505 }
506 
507 template <typename T>
508 static LogicalResult verifyPoolingOp(T op) {
509  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
510  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
511  return op.emitOpError("expect all kernel values to be >= 1, got ")
512  << kernel;
513 
514  const llvm::ArrayRef<int64_t> strides = op.getStride();
515  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
516  return op.emitOpError("expect all stride values to be >= 1, got ")
517  << strides;
518 
519  const llvm::ArrayRef<int64_t> padding = op.getPad();
520  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
521  return op.emitOpError("expect all padding values to be >= 0, got ")
522  << padding;
523 
524  // Padding must be less than kernel size to avoid a divide-by-zero
525  const int64_t kernelX = kernel[1];
526  const int64_t padLeft = padding[2];
527  const int64_t padRight = padding[3];
528  if (padRight >= kernelX || padLeft >= kernelX)
529  return op.emitOpError("expected left/right padding to be less than the "
530  "width of the kernel, got pad_left=")
531  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
532 
533  const int64_t kernelY = kernel[0];
534  const int64_t padTop = padding[0];
535  const int64_t padBottom = padding[1];
536  if (padTop >= kernelY || padBottom >= kernelY)
537  return op.emitOpError("expected top/bottom padding to be less than the "
538  "height of the kernel, got pad_top=")
539  << padTop << ", pad_bottom=" << padBottom
540  << ", kernel_y=" << kernelY;
541 
542  const auto inputType =
543  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
544  const auto outputType =
545  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
546  if (!inputType || !outputType)
547  return success();
548 
549  const auto verifyOutputSize =
550  [&op](const int64_t inputSize, const int64_t outputSize,
551  const int64_t kernelSize, const int64_t strideSize,
552  const int64_t padBefore, const int64_t padAfter,
553  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
554  const llvm::StringRef padBeforeName,
555  const llvm::StringRef padAfterName) -> LogicalResult {
556  if (ShapedType::isDynamic(inputSize))
557  return success();
558 
559  const std::optional<int64_t> calculatedOutSizeMinusOne =
560  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
561  if (!calculatedOutSizeMinusOne.has_value())
562  return op.emitOpError("expected input_")
563  << dimName << " + pad_" << padBeforeName << " + pad_"
564  << padAfterName << " - kernel_" << dimAxis
565  << " to be wholly divisible by stride_" << dimAxis << ", got ("
566  << inputSize << " + " << padBefore << " + " << padAfter << " - "
567  << kernelSize << ") / " << strideSize;
568 
569  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
570  if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
571  return op.emitOpError("calculated output ")
572  << dimName << " did not match expected: "
573  << "calculated=" << calculatedOutSize
574  << ", expected=" << outputSize;
575 
576  return success();
577  };
578 
579  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
580  kernel[0], strides[0], padding[0], padding[1],
581  "height", "y", "top", "bottom")))
582  return failure();
583 
584  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
585  kernel[1], strides[1], padding[2], padding[3],
586  "width", "x", "left", "right")))
587  return failure();
588 
589  return success();
590 }
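
// Worked example of the output-size rule above (illustrative numbers): with
// input_height = 8, pad_top = pad_bottom = 0, kernel_y = 2, stride_y = 2,
// (8 + 0 + 0 - 2) / 2 = 3 divides exactly, so the expected output height is
// 3 + 1 = 4; a declared output height other than 4 is rejected.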
591 
592 LogicalResult tosa::AvgPool2dOp::verify() {
593  if (failed(verifyPoolingOp(*this)))
594  return failure();
595 
596  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
597  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
598  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
599  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
600 
601  auto accType = getAccType();
602  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
603  return emitOpError("accumulator type for integer tensor is not i32");
604 
605  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
606  return emitOpError("accumulator type for f16 tensor is not f16/f32");
607 
608  if (inputETy.isBF16() && !accType.isF32())
609  return emitOpError("accumulator type for bf16 tensor is not f32");
610 
611  if (inputETy.isF32() && !accType.isF32())
612  return emitOpError("accumulator type for f32 tensor is not f32");
613 
614  if (inputETy != inputZpETy)
615  return emitOpError("expect both input and its zero point are the same "
616  "element type, got ")
617  << inputETy << " and " << inputZpETy;
618 
619  if (resultETy != outputZpETy)
620  return emitOpError("expect both output and its zero point are the same "
621  "element type, got ")
622  << resultETy << " and " << outputZpETy;
623 
624  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
625  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
626  return failure();
627 
628  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
629  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
630  return failure();
631 
632  return success();
633 }
634 
635 LogicalResult tosa::ClampOp::verify() {
636  mlir::Type inputETy =
637  llvm::cast<ShapedType>(getInput().getType()).getElementType();
638  if (auto quantType =
639  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
640  inputETy = quantType.getStorageType();
641  }
642  mlir::Type outputETy =
643  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
644  if (auto quantType =
645  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
646  outputETy = quantType.getStorageType();
647  }
648  if (inputETy != outputETy)
649  return emitOpError("input/output element types are incompatible.");
650 
651  auto maxValAttr = getMaxValAttr();
652  auto minValAttr = getMinValAttr();
653 
654  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
655 
656  if (inputETy.isInteger(dataTypeBitWidth)) {
657  // if input datatype is integer, check that the min_val/max_val attributes
658  // are integer attributes, and that their type is the same as the input's
659  // datatype
660  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
661  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
662  if (!intMaxValAttr || !intMinValAttr ||
663  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
664  (intMaxValAttr.getType() != inputETy))
665  return emitOpError("min/max attributes types are incompatible with "
666  "input/output element types.");
667  } else {
668  // otherwise, input datatype is float, check that the min_val/max_val
669  // attributes share the same type and that their type is the same as the
670  // input's datatype
671  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
672  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
673  if (!floatMaxValAttr || !floatMinValAttr ||
674  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
675  (floatMaxValAttr.getType() != inputETy))
676  return emitOpError("min/max attributes types are incompatible with "
677  "input/output element types.");
678  }
679 
680  return success();
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // TOSA Operator Quantization Builders.
685 //===----------------------------------------------------------------------===//
686 
687 /// This builder is called on all convolution operators except TransposeConv,
688 /// which has specialized output shape semantics. The builder also defines the
689 /// bitwidth of the output given the bit width of the input & weight content.
690 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
691  Type outputType, Value input, Value weight,
692  Value bias, DenseI64ArrayAttr pad,
693  DenseI64ArrayAttr stride,
694  DenseI64ArrayAttr dilation,
695  TypeAttr accType) {
696  auto zps = createZPsAsConst(builder, input, weight);
697  result.addOperands({input, weight, bias, zps.first, zps.second});
698  result.addAttribute("pad", pad);
699  result.addAttribute("stride", stride);
700  result.addAttribute("dilation", dilation);
701  result.addAttribute("acc_type", accType);
702  Type finalOutputType = outputType;
703  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
704  if (quantAttr) {
705  finalOutputType =
706  buildConvOpResultTypeInfo(builder, outputType, input, weight);
707  }
708  result.addTypes(finalOutputType);
709 }
710 
711 /// Handles tosa.transpose_conv2d which has outpad and output shape
712 /// attributes.
713 static void
714 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
715  Type outputType, Value input, Value weight,
716  Value bias, DenseI64ArrayAttr outpad,
717  DenseI64ArrayAttr stride, TypeAttr accType) {
718  auto zps = createZPsAsConst(builder, input, weight);
719  result.addOperands({input, weight, bias, zps.first, zps.second});
720  result.addAttribute("out_pad", outpad);
721  result.addAttribute("stride", stride);
722  result.addAttribute("acc_type", accType);
723  Type finalOutputType = outputType;
724  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
725  if (quantAttr) {
726  finalOutputType =
727  buildConvOpResultTypeInfo(builder, outputType, input, weight);
728  }
729  result.addTypes(finalOutputType);
730 }
731 
732 /// The tosa.matmul op is also intended to be generated where a fully_connected
733 /// op must be constructed but the weight is not a constant. In this case,
734 /// the fully_connected op must be expressed using matmul.
735 /// TODO: Add link to the legalization document explaining this.
736 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
737  OperationState &result, Type outputType,
738  Value a, Value b) {
739  auto zps = createZPsAsConst(builder, a, b);
740  result.addOperands({a, b, zps.first, zps.second});
741 
742  Type finalOutputType{outputType};
743  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
744  auto eType = getStorageElementTypeOrSelf(a.getType());
745  auto inputBits = eType.getIntOrFloatBitWidth();
746 
747  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
748  assert(outputShapedType && "Output must be a shaped type");
749 
750  IntegerType accElementType;
751  if (inputBits == 16)
752  accElementType = builder.getIntegerType(48);
753  else
754  accElementType = builder.getI32Type();
755 
756  finalOutputType = outputShapedType.clone(accElementType);
757  }
758  result.addTypes(finalOutputType);
759 }
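
// Illustrative behaviour of the builder above: when both operands carry
// quantized element types with i8 storage, the result element type is widened
// to the i32 accumulator type; for i16 storage it becomes i48. Otherwise the
// caller-supplied outputType is used unchanged.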
760 
761 /// Both the tosa.avg_pool2d and unary ops use the same
762 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
763 /// as it has additional parameters not part of the unary ops.
764 static void
765 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
766  Type outputType, Value input,
767  DenseArrayAttr kernel, DenseArrayAttr stride,
768  DenseArrayAttr pad, TypeAttr accType) {
769  const Location loc{result.location};
770  int64_t inputZp{0};
771  int64_t outputZp{0};
772 
773  if (auto quantAttr =
774  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
775  inputZp = quantAttr.getInputZp();
776  outputZp = quantAttr.getOutputZp();
777  }
778  const std::optional<Value> inputZpOp =
779  createZeroPointTensor(builder, loc, input.getType(), inputZp);
780  if (!inputZpOp) {
781  (void)emitError(
782  loc,
783  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
784  }
785  const std::optional<Value> outputZpOp =
786  createZeroPointTensor(builder, loc, outputType, outputZp);
787  if (!outputZpOp) {
788  (void)emitError(loc, "Failed to create output zero point tensor for "
789  "quantized AVG_POOL2D op");
790  }
791 
792  if (inputZpOp && outputZpOp) {
793  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
794  } else {
795  // Failed to create one or more zero points above: just add input as
796  // operands. This will trigger an error in building the op because of the
797  // missing zero points.
798  result.addOperands({input});
799  }
800  result.addAttribute("kernel", kernel);
801  result.addAttribute("stride", stride);
802  result.addAttribute("pad", pad);
803  result.addAttribute("acc_type", accType);
804  result.types.push_back(outputType);
805 }
806 
807 /// This builder is called on the single-parameter negate operator
808 /// to construct input and output zero points based on their
809 /// types.
810 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
811  OperationState &result, Type outputType,
812  Value input) {
813  const Location loc{result.location};
814  int64_t input1Zp{0};
815  int64_t outputZp{0};
816  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
817  if (quantAttr) {
818  input1Zp = quantAttr.getInputZp();
819  outputZp = quantAttr.getOutputZp();
820  }
821  const std::optional<Value> input1ZpOp =
822  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
823  if (!input1ZpOp) {
824  (void)emitError(
825  loc, "Failed to create input1 zero point for quantized NEGATE op");
826  }
827 
828  const std::optional<Value> outputZpOp =
829  createZeroPointTensor(builder, loc, input.getType(), outputZp);
830  if (!outputZpOp) {
831  (void)emitError(
832  loc, "Failed to create output zero point for quantized NEGATE op");
833  }
834 
835  if (input1ZpOp && outputZpOp) {
836  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
837  } else {
838  // failed to create one or more zero points above: just add input as
839  // operands. This will trigger error in building the op because of
840  // missing zero points
841  result.addOperands({input});
842  }
843 
844  result.types.push_back(outputType);
845 }
846 
847 /// This builder is called on the TOSA pad operator, which needs to create its
848 /// own OptionalAttr quantization_attr parameter to scale the padding values
849 /// correctly. When no pad_const is provided, zero padding is assumed.
850 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
851  Type outputType, Value input,
852  Value paddings) {
853  const Location loc{result.location};
854  int32_t zp{0};
855  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
856  if (quantAttr) {
857  zp = static_cast<int32_t>(quantAttr.getInputZp());
858  }
859  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
860  result.addOperands({input, paddings, padConstOp});
861  result.types.push_back(outputType);
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // TOSA Operator Return Type Inference.
866 //===----------------------------------------------------------------------===//
867 
868 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
869  SmallVector<int64_t> &outShape) {
870  int64_t outRank = 0;
871  for (int i = 0, e = operands.size(); i != e; ++i) {
872  auto shape = operands.getShape(i);
873  if (!shape.hasRank()) {
874  // TODO(jennik): Update function to have better case handling for
875  // invalid operands and for ranked tensors.
876  return failure();
877  }
878  outRank = std::max<int64_t>(outRank, shape.getRank());
879  }
880 
881  outShape.resize(outRank, 1);
882 
883  for (int i = 0, e = operands.size(); i != e; ++i) {
884  auto shape = operands.getShape(i);
885  auto rankDiff = outShape.size() - shape.getRank();
886 
887  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
888  auto dim1 = outShape[i + rankDiff];
889  auto dim2 = shape.getDimSize(i);
890  auto resolvedDim = dim1;
891 
892  if (dim1 == 1) {
893  resolvedDim = dim2;
894  } else if (dim2 == 1) {
895  resolvedDim = dim1;
896  } else if (dim1 != dim2) {
897  return failure();
898  }
899  outShape[i + rankDiff] = resolvedDim;
900  }
901  }
902 
903  return success();
904 }
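
// Broadcasting sketch (illustrative shapes): operands of shape [2, 1, 3] and
// [4, 1] resolve to the output shape [2, 4, 3]; shapes such as [2, 3] and
// [4, 3] fail because neither mismatching dimension is 1.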
905 
906 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
907  MLIRContext *context, ::std::optional<Location> location,
908  ArgMaxOp::Adaptor adaptor,
909  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
910  ShapeAdaptor inputShape(adaptor.getInput().getType());
911  IntegerAttr axis = adaptor.getProperties().axis;
912  int32_t axisVal = axis.getValue().getSExtValue();
913 
914  if (!inputShape.hasRank()) {
915  inferredReturnShapes.push_back(ShapedTypeComponents());
916  return success();
917  }
918 
919  SmallVector<int64_t> outShape;
920  outShape.reserve(inputShape.getRank() - 1);
921  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
922  if (i == axisVal)
923  continue;
924  outShape.push_back(inputShape.getDimSize(i));
925  }
926 
927  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
928  return success();
929 }
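
// Example: an input of shape [2, 4, 8] with axis = 1 infers the output shape
// [2, 8]; the reduced axis is simply dropped.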
930 
931 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
932  MLIRContext *context, ::std::optional<Location> location,
933  RFFT2dOp::Adaptor adaptor,
934  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
935  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
936 
937  if (!inputShape.hasRank())
938  return failure();
939 
940  llvm::SmallVector<int64_t> outputShape;
941  outputShape.resize(3, ShapedType::kDynamic);
942  outputShape[0] = inputShape.getDimSize(0);
943  outputShape[1] = inputShape.getDimSize(1);
944  int64_t inWidth = inputShape.getDimSize(2);
945 
946  // Note that we can support this calculation symbolically
947  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
948  if (inWidth != ShapedType::kDynamic)
949  outputShape[2] = inWidth / 2 + 1;
950 
951  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
952  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
953 
954  return success();
955 }
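
// Example: a real input of shape [8, 16, 32] infers two outputs of shape
// [8, 16, 17], since the innermost dimension becomes 32 / 2 + 1 = 17.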
956 
957 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
958  const llvm::StringRef dimName) {
959  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
960  if (!isPowerOfTwo)
961  return op->emitOpError("expected ")
962  << dimName << " to be a power of two, got " << dimSize;
963 
964  return success();
965 }
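
// The check above relies on the usual bit trick: for a positive dimSize,
// dimSize & (dimSize - 1) clears the lowest set bit, so the result is zero
// exactly when a single bit is set, e.g. 8 & 7 == 0 but 6 & 5 == 4.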
966 
967 LogicalResult tosa::RFFT2dOp::verify() {
968  const auto outputTypes = getResultTypes();
969  if (failed(verifyCompatibleShapes(outputTypes)))
970  return emitOpError("expected output shapes to match, got ") << outputTypes;
971 
972  const auto inputType =
973  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
974  if (!inputType)
975  return success();
976 
977  const int64_t height = inputType.getDimSize(1);
978  if (!ShapedType::isDynamic(height) &&
979  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
980  return failure();
981 
982  const int64_t width = inputType.getDimSize(2);
983  if (!ShapedType::isDynamic(width) &&
984  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
985  return failure();
986 
987  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
988  if (!outputType)
989  return success();
990 
991  // Batch and height input/output dimensions should match
992  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
993  outputType.getShape().drop_back())))
994  return emitOpError("expected batch and height dimensions of input/output "
995  "to match, got input=")
996  << inputType << " output=" << outputType;
997 
998  // Output width dimension expected to be input_width / 2 + 1
999  const int64_t outputWidth = outputType.getDimSize(2);
1000  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1001  (outputWidth != (width / 2) + 1))
1002  return emitOpError(
1003  "expected output width to be equal to input_width / 2 + 1, got ")
1004  << outputWidth;
1005 
1006  return success();
1007 }
1008 
1009 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1010  MLIRContext *context, ::std::optional<Location> location,
1011  FFT2dOp::Adaptor adaptor,
1012  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1013  inferredReturnShapes.push_back(
1014  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1015  inferredReturnShapes.push_back(
1016  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1017  return success();
1018 }
1019 
1020 LogicalResult tosa::FFT2dOp::verify() {
1021  const auto inputRealType =
1022  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1023  const auto inputImagType =
1024  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1025  if (!inputRealType || !inputImagType)
1026  return success();
1027 
1028  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1029  return ShapedType::isDynamic(a) ? a : b;
1030  };
1031 
1032  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1033  inputImagType.getDimSize(1));
1034  if (!ShapedType::isDynamic(height) &&
1035  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1036  return failure();
1037 
1038  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1039  inputImagType.getDimSize(2));
1040  if (!ShapedType::isDynamic(width) &&
1041  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1042  return failure();
1043 
1044  return success();
1045 }
1046 
1047 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1048  MLIRContext *context, ::std::optional<Location> location,
1049  ConcatOp::Adaptor adaptor,
1050  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1051  // Infer all dimension sizes by reducing based on inputs.
1052  const Properties &prop = adaptor.getProperties();
1053  int32_t axis = prop.axis.getValue().getSExtValue();
1054  llvm::SmallVector<int64_t> outputShape;
1055  bool hasRankedInput = false;
1056  for (auto operand : adaptor.getOperands()) {
1057  ShapeAdaptor operandShape(operand.getType());
1058  if (!operandShape.hasRank())
1059  continue;
1060 
1061  // Copy the Operand's rank.
1062  if (!hasRankedInput)
1063  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1064 
1065  // Copy shapes until the dim is non-dynamic.
1066  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1067  if (i == axis || operandShape.isDynamicDim(i))
1068  continue;
1069  if (outputShape[i] == ShapedType::kDynamic)
1070  outputShape[i] = operandShape.getDimSize(i);
1071  if (outputShape[i] != operandShape.getDimSize(i))
1072  return emitOptionalError(location,
1073  "Cannot concat tensors with different sizes"
1074  " on the non-axis dimension ",
1075  i);
1076  }
1077 
1078  hasRankedInput = true;
1079  }
1080  Type inputType =
1081  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1082  if (!hasRankedInput) {
1083  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1084  return success();
1085  }
1086 
1087  // Determine the dimension size along the concatenation axis.
1088  int64_t concatDimSize = 0;
1089  for (auto operand : adaptor.getOperands()) {
1090  ShapeAdaptor operandShape(operand.getType());
1091 
1092  // We need to know the length of the concatenation axis of all inputs to
1093  // determine the dimension size of the output shape.
1094  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1095  concatDimSize = ShapedType::kDynamic;
1096  break;
1097  }
1098 
1099  concatDimSize += operandShape.getDimSize(axis);
1100  }
1101 
1102  outputShape[axis] = concatDimSize;
1103 
1104  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1105  return success();
1106 }
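
// Example: concatenating tensor<2x3xf32> and tensor<2x5xf32> on axis 1 infers
// tensor<2x8xf32>; if any input's axis dimension is dynamic, the inferred axis
// dimension is dynamic as well.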
1107 
1108 LogicalResult tosa::ConcatOp::verify() {
1109  // check that each input has same element type as output
1110  auto outType = getOutput().getType();
1111  const Operation::operand_range inputList = getInput1();
1112 
1113  // Check there is at least one input
1114  if (inputList.empty())
1115  return emitOpError("expect at least one input");
1116 
1117  if (!llvm::all_of(inputList, [&](auto input) {
1118  return succeeded(verifySameElementTypes(
1119  *this, /* inType = */ input.getType(), outType));
1120  })) {
1121  return failure();
1122  }
1123 
1124  const int32_t axis = getAxis();
1125  ShapeAdaptor firstRankedInputShape = nullptr;
1126  for (const auto &input : inputList) {
1127  const Type inputType = input.getType();
1128  ShapeAdaptor currShape(inputType);
1129  if (currShape.hasRank()) {
1130  firstRankedInputShape = currShape;
1131  // Check axis is in expected range
1132  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1133  return emitOpError("expect axis to be within range 0 <= axis < "
1134  "rank(input1[firstRankedTensorIdx]), got ")
1135  << axis;
1136  break;
1137  }
1138  }
1139 
1140  const auto allOperandsHasRank = [](const Value input) {
1141  return ShapeAdaptor(input.getType()).hasRank();
1142  };
1143  if (llvm::all_of(inputList, allOperandsHasRank)) {
1144  const int64_t firstInputRank = firstRankedInputShape.getRank();
1145 
1146  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1147  const ShapeAdaptor inputShape(input.getType());
1148  const int64_t inputRank = inputShape.getRank();
1149  const size_t operandNum = index + 1;
1150 
1151  // Check that each operand has the same rank
1152  if (inputRank != firstInputRank)
1153  return emitOpError(
1154  "expect all operands to have the same rank, but got ")
1155  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1156  << operandNum;
1157 
1158  // Check non-axis dims match
1159  for (int i = 0; i < inputRank; i++) {
1160  const int64_t inputDim = inputShape.getDimSize(i);
1161  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1162  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1163  inputShape.isDynamicDim(i))
1164  continue;
1165  if (inputDim != firstInputDim)
1166  return emitOpError("expect all operand shapes to have the same sizes "
1167  "on non-axis dimensions, but got ")
1168  << inputDim << " vs " << firstInputDim << " at index " << i
1169  << " on operands 0 and " << operandNum;
1170  }
1171  }
1172  }
1173 
1174  return success();
1175 }
1176 
1177 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1178  MLIRContext *context, ::std::optional<Location> location,
1179  ValueShapeRange operands, DictionaryAttr attributes,
1180  OpaqueProperties properties, RegionRange regions,
1181  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1182  auto elementType = IntegerType::get(context, /*width=*/1);
1183 
1184  llvm::SmallVector<int64_t> outShape;
1185  if (resolveBroadcastShape(operands, outShape).failed()) {
1186  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1187  return success();
1188  }
1189 
1190  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1191  return success();
1192 }
1193 
1194 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1195  if (l.size() != r.size() || l.size() != 1)
1196  return false;
1197  return succeeded(verifyCompatibleShape(l[0], r[0]));
1198 }
1199 
1200 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1201  MLIRContext *context, ::std::optional<Location> location,
1202  MatMulOp::Adaptor adaptor,
1203  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1204  ShapeAdaptor lhsShape(adaptor.getA().getType());
1205  ShapeAdaptor rhsShape(adaptor.getB().getType());
1206 
1207  // All shapes are dynamic.
1208  SmallVector<int64_t> outShape;
1209  outShape.resize(3, ShapedType::kDynamic);
1210 
1211  if (lhsShape.hasRank()) {
1212  outShape[0] = lhsShape.getDimSize(0);
1213  outShape[1] = lhsShape.getDimSize(1);
1214  }
1215 
1216  if (rhsShape.hasRank()) {
1217  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1218  : outShape[0];
1219  outShape[2] = rhsShape.getDimSize(2);
1220  }
1221 
1222  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1223  return success();
1224 }
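
// Example: with A of shape [1, 4, 8] and B of shape [1, 8, 16], the inferred
// output shape is [1, 4, 16]; the batch dimension is taken from whichever
// operand has it statically known.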
1225 
1226 LogicalResult MatMulOp::verify() {
1227  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1228  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1229 
1230  // Must be shaped tensor types
1231  if (!aType)
1232  return emitOpError("expect a shaped tensor for input a, got ")
1233  << getA().getType();
1234 
1235  if (!bType)
1236  return emitOpError("expect a shaped tensor for input b, got ")
1237  << getB().getType();
1238 
1239  auto aElementType = aType.getElementType();
1240  auto bElementType = bType.getElementType();
1241 
1242  auto aQuantizedEType =
1243  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1244  auto bQuantizedEType =
1245  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1246 
1247  if (aQuantizedEType || bQuantizedEType) {
1248  if (!aQuantizedEType || !bQuantizedEType) {
1249  return emitOpError("expect operands to be both quantized or both not "
1250  "quantized, got ")
1251  << aElementType << " and " << bElementType;
1252  }
1253  // both a and b have quantized element types
1254  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1255  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1256  if (aQuantWidth != bQuantWidth) {
1257  return emitOpError("expect quantized operands to have same widths, got ")
1258  << aQuantWidth << " and " << bQuantWidth;
1259  }
1260  } else {
1261  // non-quantized element types
1262  if (aElementType != bElementType) {
1263  return emitOpError("expect same element type for inputs a and b, got ")
1264  << aElementType << " and " << bElementType;
1265  }
1266  }
1267 
1268  // check a_zp and b_zp
1269  auto aEType = getStorageElementTypeOrSelf(aType);
1270  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1271  if (aEType != aZpEType) {
1272  return emitOpError("expect input a and a_zp have the same "
1273  "element type, got ")
1274  << aEType << " and " << aZpEType;
1275  }
1276 
1277  auto bEType = getStorageElementTypeOrSelf(bType);
1278  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1279  if (bEType != bZpEType) {
1280  return emitOpError("expect input b and b_zp have the same "
1281  "element type, got ")
1282  << bEType << " and " << bZpEType;
1283  }
1284 
1285  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1286  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1287  return failure();
1288 
1289  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1290  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1291  return failure();
1292 
1293  return success();
1294 }
1295 
1296 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1297  MLIRContext *context, ::std::optional<Location> location,
1298  PadOp::Adaptor adaptor,
1299  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1300  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1301  auto paddingRank =
1302  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1303  SmallVector<int64_t> outputShape;
1304 
1305  // If the input rank is unknown, we can infer the output rank using the
1306  // padding shape's rank divided by 2.
1307  if (!inputShape.hasRank()) {
1308  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1309  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1310  return success();
1311  }
1312 
1313  SmallVector<int64_t> paddingValues;
1314  // If the padding value is not a constant, all dimensions must be dynamic.
1315  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1316  paddingValues)) {
1317  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1318  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1319  return success();
1320  }
1321 
1322  outputShape.reserve(inputShape.getRank());
1323  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1324  if (inputShape.isDynamicDim(i)) {
1325  outputShape.push_back(ShapedType::kDynamic);
1326  continue;
1327  }
1328  auto padFront = paddingValues[i * 2];
1329  auto padBack = paddingValues[i * 2 + 1];
1330  if (padFront < 0 || padBack < 0) {
1331  // if either padding for dim i is -1, output dim is unknown
1332  outputShape.push_back(ShapedType::kDynamic);
1333  continue;
1334  }
1335 
1336  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1337  }
1338 
1339  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1340  return success();
1341 }
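
// Example: padding an input of shape [3, 4] with constant padding values
// [1, 1, 2, 2] (before/after per dimension) infers the output shape
// [3 + 1 + 1, 4 + 2 + 2] = [5, 8]; non-constant paddings give dynamic dims.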
1342 
1343 LogicalResult tosa::PadOp::verify() {
1344  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1345  /* outType = */ getOutput().getType())
1346  .failed()) {
1347  return failure();
1348  }
1349 
1350  if (auto padConst = getPadConst()) {
1351  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1352  /* outType = */ getOutput().getType())
1353  .failed()) {
1354  return failure();
1355  }
1356  }
1357 
1358  RankedTensorType inputType = getInput1().getType();
1359  RankedTensorType outputType = getOutput().getType();
1360  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
1361 
1362  if (inputType.getRank() != outputType.getRank())
1363  return emitOpError() << "expect same input and output tensor rank.";
1364 
1365  if (paddingRank != inputType.getRank() * 2)
1366  return emitOpError() << "expected padding tensor dim 0 to have size "
1367  << inputType.getRank() * 2
1368  << " (2*rank(shape1)) but got size " << paddingRank;
1369 
1370  return success();
1371 }
1372 
1373 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
1374  return to_vector(llvm::map_range(shape, [](int64_t dim) {
1375  return dim == -1 ? ShapedType::kDynamic : dim;
1376  }));
1377 }
1378 
1379 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1380  MLIRContext *context, ::std::optional<Location> location,
1381  SliceOp::Adaptor adaptor,
1382  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1383 
1384  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1385  SmallVector<int64_t> start;
1386  SmallVector<int64_t> size;
1387 
1388  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1389  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1390  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1391  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1392  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1393  return success();
1394  }
1395 
1396  // if size[i] is -1, all remaining elements in dimension i are included
1397  // in the slice, similar to TF.
1398  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1399  // initialize outputShape to all unknown
1400  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1401  if (inputShape.hasRank()) {
1402  for (size_t i = 0; i < size.size(); i++) {
1403  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1404  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1405  start[i] < inputShape.getDimSize(i))) {
1406  // size[i] is not 0 and not < -1, and start[i] is in valid range
1407  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1408  // input shape has unknown dim[i] - only valid if size[i] > 0
1409  if (size[i] > 0) {
1410  outputShape[i] = size[i];
1411  }
1412  } else {
1413  // input shape has known dim[i]
1414  if (size[i] == -1) {
1415  outputShape[i] = inputShape.getDimSize(i) - start[i];
1416  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1417  // start[i] + size[i] is within bound of input shape's dim[i]
1418  outputShape[i] = size[i];
1419  }
1420  }
1421  }
1422  }
1423  } else {
1424  outputShape = convertToMlirShape(size);
1425  }
1426  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1427  return success();
1428 }
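
// Example: slicing an input of shape [4, 8] with start = [1, 2] and
// size = [2, -1] infers the output shape [2, 6], since a size of -1 extends to
// the end of that dimension (8 - 2 = 6).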
1429 
1430 LogicalResult tosa::SliceOp::verify() {
1431  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1432  /* outType = */ getOutput().getType())
1433  .failed())
1434  return failure();
1435  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1436  if (!inputType)
1437  return success();
1438 
1439  auto startShapeRank =
1440  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1441  if (inputType.getRank() != startShapeRank)
1442  return emitOpError("length of start is not equal to rank of input shape");
1443 
1444  auto sizeShapeRank =
1445  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1446  if (inputType.getRank() != sizeShapeRank)
1447  return emitOpError("length of size is not equal to rank of input shape");
1448 
1449  return success();
1450 }
1451 
1452 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1453  MLIRContext *context, ::std::optional<Location> location,
1454  ValueShapeRange operands, DictionaryAttr attributes,
1455  OpaqueProperties properties, RegionRange regions,
1456  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1457  // mul op's output shape only depends on input1 and input2, not on shift
1458  ValueShapeRange twoInputs = operands.drop_back();
1459  llvm::SmallVector<int64_t> outShape;
1460  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1461  inferredReturnShapes.push_back(ShapedTypeComponents());
1462  } else {
1463  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1464  }
1465  return success();
1466 }
1467 
1468 LogicalResult tosa::MulOp::verify() {
1469  auto resElemType = getElementTypeOrSelf(getOutput());
1470 
1471  // Verify if the element type among operands and result match tosa
1472  // specification.
1473  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1474  IntegerType lhsIntType =
1475  cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1476  IntegerType rhsIntType =
1477  cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1478  if (lhsIntType != rhsIntType)
1479  return emitOpError("requires the same element type for all operands");
1480 
1481  // Though the spec requires the element type of the result to be i32, a more
1482  // relaxed rule is provided at the dialect level to ease interoperation with
1483  // other dialects.
1484  if (lhsIntType.getWidth() > resIntType.getWidth())
1485  return emitOpError("invalid data type size for operands or result");
1486 
1487  } else {
1488  // For the other supported types, the spec requires the same element
1489  // type for all operands (excluding the `shift` operand) and results.
1490  for (int i = 0; i < 2; ++i) {
1491  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1492  return emitOpError(
1493  "requires the same element type for all operands and results");
1494  }
1495 
1496  // verify shift has value 0 for non-integer types
1497  ElementsAttr shift_elem;
1498  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1499  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1500  if (shift != 0) {
1501  return emitOpError() << "require shift to be 0 for float type";
1502  }
1503  }
1504  }
1505 
1506  // Verify that the op has the same rank for all main operands (extra operands
1507  // such as the shift of mul op are excluded, which is the only difference from
1508  // the built-in `SameOperandsAndResultRank` trait) and result types, if known.
1509 
1510  // delegate function that returns true if type is a shaped type with known
1511  // rank
1512  auto hasRank = [](const Type type) {
1513  if (auto shaped_type = dyn_cast<ShapedType>(type))
1514  return shaped_type.hasRank();
1515 
1516  return false;
1517  };
1518 
1519  auto rankedOperandTypes =
1520  llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1521 
1522  auto rankedResultTypes =
1523  llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1524 
1525  // If all operands and results are unranked, no further verification is needed.
1526  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1527  return success();
1528 
1529  // delegate function that returns rank of shaped type with known rank
1530  auto getRank = [](const Type type) {
1531  return cast<ShapedType>(type).getRank();
1532  };
1533 
1534  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1535  : getRank(*rankedResultTypes.begin());
1536 
1537  for (size_t i = 0; i < 2; ++i) {
1538  if (rank != getRank(rankedOperandTypes[i])) {
1539  return emitOpError("operands don't have matching ranks");
1540  }
1541  }
1542 
1543  for (const auto type : rankedResultTypes) {
1544  if (rank != getRank(type)) {
1545  return emitOpError("result type has different rank than operands");
1546  }
1547  }
1548 
1549  // Check for broadcast-compatible shapes in the first two operands (ignoring
1550  // shift).
1551 
1552  // delegate function that returns shape of shaped type
1553  auto getShape = [](const Type type) {
1554  return mlir::cast<ShapedType>(type).getShape();
1555  };
1556  SmallVector<int64_t> resultShape;
1557  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1558  getShape(rankedOperandTypes[1]),
1559  resultShape)) {
1560  return emitOpError("operands don't have broadcast-compatible shapes");
1561  }
1562 
1563  return success();
1564 }
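
// Illustrative note on the integer rule above (an example, not spec text):
// the check only rejects results narrower than the operands, so i8 * i8 -> i32
// is accepted while i32 * i32 -> i16 is rejected.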
1565 
1566 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1567  MLIRContext *context, ::std::optional<Location> location,
1568  TableOp::Adaptor adaptor,
1569  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1570  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1571 
1572  if (!inputShape.hasRank()) {
1573  inferredReturnShapes.push_back(ShapedTypeComponents());
1574  return success();
1575  }
1576 
1577  inferredReturnShapes.resize(1);
1578  inputShape.getDims(inferredReturnShapes[0]);
1579  return success();
1580 }
1581 
1582 LogicalResult tosa::TableOp::verify() {
1583  TensorType inputType = getInput1().getType();
1584  TensorType outputType = getOutput().getType();
1585 
1586  if (inputType.hasRank() && outputType.hasRank() &&
1587  inputType.getRank() != outputType.getRank())
1588  return emitOpError()
1589  << "expected input tensor rank to equal result tensor rank";
1590 
1591  auto inputDims = inputType.getShape();
1592  auto outputDims = outputType.getShape();
1593  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1594  int64_t dim = it.index();
1595  auto [inputDim, outputDim] = it.value();
1596  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1597  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1598  << " doesn't match dim(input, " << dim
1599  << ") = " << inputDim;
1600  }
1601  }
1602  return success();
1603 }
1604 
1605 LogicalResult
1606 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1607  // Multiples must be constants.
1608  DenseIntElementsAttr multiplesAttr;
1609  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1610  return failure();
1611  multiples = llvm::to_vector(
1612  llvm::map_range(multiplesAttr.getValues<APInt>(),
1613  [](const APInt &val) { return val.getSExtValue(); }));
1614  return success();
1615 }
1616 
1617 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1618  MLIRContext *context, ::std::optional<Location> location,
1619  TileOp::Adaptor adaptor,
1620  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1621  DenseIntElementsAttr multiplesAttr;
1622  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1623  return failure();
1624 
1625  SmallVector<int64_t> multiples = llvm::to_vector(
1626  llvm::map_range(multiplesAttr.getValues<APInt>(),
1627  [](const APInt &val) { return val.getSExtValue(); }));
1628 
1629  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1630  SmallVector<int64_t> outputShape;
1631  if (!inputShape.hasRank()) {
1632  outputShape.resize(multiples.size(), ShapedType::kDynamic);
1633  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1634  return success();
1635  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1636  return failure();
1637 
1638  // Each non-dynamic dimension is multiplied by its multiple to a known size.
1639  outputShape.reserve(multiples.size());
1640  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1641  int64_t dim = inputShape.getDimSize(i);
1642  if (dim != ShapedType::kDynamic)
1643  dim *= multiples[i];
1644  outputShape.push_back(dim);
1645  }
1646 
1647  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1648  return success();
1649 }
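
// Worked example of the inference above (illustrative values): an input of
// shape [2, ?, 3] with multiples [3, 2, 4] yields an output shape of
// [2 * 3, ?, 3 * 4] = [6, ?, 12]; dynamic dimensions stay dynamic.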
1650 
1651 LogicalResult tosa::TileOp::verify() {
1652  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1653  /* outType = */ getOutput().getType())
1654  .failed()) {
1655  return failure();
1656  }
1657  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1658  ShapedType outputType = llvm::cast<ShapedType>(getType());
1659 
1660  shapeType multiplesType =
1661  llvm::cast<tosa::shapeType>(getMultiples().getType());
1662 
1663  auto multiplesRank = multiplesType.getRank();
1664 
1665  if (inputType.hasRank()) {
1666  if (inputType.getRank() != multiplesRank)
1667  return emitOpError("expect 'multiples' to have rank ")
1668  << inputType.getRank() << " but got " << multiplesRank << ".";
1669  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1670  return emitOpError("expect same input and output tensor rank.");
1671  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1672  return emitOpError("expect 'multiples' array to have length ")
1673  << outputType.getRank() << " but got " << multiplesRank << ".";
1674 
1675  SmallVector<int64_t> multiples;
1676  if (getConstantMultiples(multiples).succeeded() &&
1677  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1678  return emitOpError(
1679  "expect element of 'multiples' to be positive integer or -1.");
1680 
1681  return success();
1682 }
1683 
1684 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1685  if (l.size() != r.size() || l.size() != 1)
1686  return false;
1687  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
1688 }
1689 
1690 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1691  MLIRContext *context, ::std::optional<Location> location,
1692  ReshapeOp::Adaptor adaptor,
1693  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1694  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1695  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1696  llvm::SmallVector<int64_t> newShapeValue;
1697  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
1698  newShapeValue)) {
1699  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1700  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1701  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1702  return success();
1703  } else {
1704  newShapeValue = convertToMlirShape(newShapeValue);
1705  }
1706 
1707  // We cannot infer from the total number of elements so we must take the
1708  // shape attribute as exact.
1709  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1710  inferredReturnShapes.push_back(
1711  ShapedTypeComponents(newShapeValue, inputType));
1712  return success();
1713  }
1714 
1715  // Determine the number of elements covered by the product of all static
1716  // dimensions. This allows us to infer the length of the remaining dynamic
1717  // dimension.
1718  int64_t numElements = inputShape.getNumElements();
1719  int64_t staticMul = 1;
1720  for (auto val : newShapeValue) {
1721  if (!ShapedType::isDynamic(val)) {
1722  staticMul *= val;
1723  }
1724  }
1725 
1726  // Determine the length of the dynamic dimension.
1727  for (auto &val : newShapeValue) {
1728  if (ShapedType::isDynamic(val))
1729  val = numElements / staticMul;
1730  }
1731 
1732  inferredReturnShapes.push_back(
1733  ShapedTypeComponents(newShapeValue, inputType));
1734  return success();
1735 }
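
// Worked example of the dynamic-dimension inference above (illustrative
// values): reshaping a static 2x3x4 input (24 elements) to new shape [4, -1]
// gives staticMul = 4, so the dynamic dimension becomes 24 / 4 = 6 and the
// inferred output shape is [4, 6].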
1736 
1737 llvm::LogicalResult tosa::ReshapeOp::verify() {
1738  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1739  /* outType = */ getOutput().getType())
1740  .failed()) {
1741  return failure();
1742  }
1743  TensorType inputType = getInput1().getType();
1744  RankedTensorType outputType = getType();
1745 
1746  SmallVector<int64_t> shapeValues;
1747  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
1748  // skip following checks if shape is not constant
1749  return mlir::success();
1750  }
1751 
1752  if ((int64_t)shapeValues.size() != outputType.getRank())
1753  return emitOpError() << "new shape does not match result rank";
1754 
1755  for (auto [newShapeDim, outputShapeDim] :
1756  zip(shapeValues, outputType.getShape())) {
1757  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1758  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
1759  return emitOpError() << "new shape is inconsistent with result shape";
1760 
1761  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1762  return emitOpError() << "new shape has invalid tensor dimension size "
1763  << newShapeDim;
1764  }
1765 
1766  if (inputType.hasStaticShape()) {
1767  int64_t inputElementsNum = inputType.getNumElements();
1768  if (outputType.hasStaticShape()) {
1769  int64_t outputElementsNum = outputType.getNumElements();
1770  if (inputElementsNum != outputElementsNum) {
1771  return emitOpError() << "cannot reshape " << inputElementsNum
1772  << " elements into " << outputElementsNum;
1773  }
1774  }
1775 
1776  int64_t newShapeElementsNum = std::accumulate(
1777  shapeValues.begin(), shapeValues.end(), 1LL,
1778  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1779  bool isStaticNewShape =
1780  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
1781  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1782  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1783  return emitOpError() << "cannot reshape " << inputElementsNum
1784  << " elements into " << newShapeElementsNum;
1785  }
1786  }
1787 
1788  int missingDims = llvm::count(shapeValues, -1);
1789  if (missingDims > 1)
1790  return emitOpError() << "expected at most one target dimension to be -1";
1791 
1792  return mlir::success();
1793 }
1794 
1795 // Returns failure if val is not a constant.
1796 // Returns -1 if val is a non-zero float or is neither an integer nor a float;
1797 // otherwise returns val's constant value as the zero point.
1798 template <typename T>
1799 static FailureOr<int64_t> getZeroPoint(T op, Value val) {
1800  ElementsAttr zpAttr;
1801  if (!matchPattern(val, m_Constant(&zpAttr))) {
1802  return failure();
1803  }
1804 
1805  Type zpElemType = zpAttr.getElementType();
1806 
1807  if (llvm::isa<FloatType>(zpElemType)) {
1808  if (zpAttr.getValues<APFloat>()[0].isZero()) {
1809  return 0;
1810  }
1811  // return non-zero value to trigger error check
1812  return -1;
1813  }
1814 
1815  if (llvm::isa<IntegerType>(zpElemType)) {
1816  return zpAttr.getValues<APInt>()[0].getSExtValue();
1817  }
1818 
1819  // return non-zero value to trigger error check
1820  return -1;
1821 }
1822 
1823 template <typename T>
1824 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
1825  const std::string &operand) {
1826  Type zpElemType = getElementTypeOrSelf(val);
1827 
1828  if (!zpElemType.isInteger(8) && zp != 0) {
1829  // convert operand to lower case for error message
1830  std::string lower = operand;
1831  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
1832  return op.emitOpError()
1833  << lower << " zero point must be zero for non-int8 integer types";
1834  }
1835 
1836  return success();
1837 }
1838 
1839 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
1840  const int64_t &zp,
1841  const std::string &operand) {
1842  bool isInputZp = (operand == "Input");
1843 
1844  bool tensorUnsigned =
1845  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
1846  StringRef tensorName = isInputZp ? "input" : "output";
1847 
1848  Type zpElemType = getElementTypeOrSelf(zpVal);
1849 
1850  if (zp != 0) {
1851  if (!zpElemType.isInteger(8) &&
1852  !(zpElemType.isInteger(16) && tensorUnsigned)) {
1853  return op.emitOpError()
1854  << "expect " << tensorName << "_zp of 0, got " << zp;
1855  }
1856  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
1857  return op.emitOpError() << "expect " << tensorName
1858  << "_zp of 0 or 32768 for unsigned int16 "
1859  << tensorName << ", got " << zp;
1860  }
1861  }
1862 
1863  return success();
1864 }
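
// Summarizing the check above (illustrative, not spec wording): an i8 zero
// point may take any value here, an unsigned i16 zero point may be 0 or
// 32768, and every other integer type requires a zero point of 0.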
1865 
1866 #define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
1867  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
1868  return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
1869  } \
1870  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
1871  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
1872  }
1873 
1874 ZERO_POINT_HELPER(Conv2DOp, Input)
1875 ZERO_POINT_HELPER(Conv2DOp, Weight)
1876 ZERO_POINT_HELPER(Conv3DOp, Input)
1877 ZERO_POINT_HELPER(Conv3DOp, Weight)
1878 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
1879 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
1880 ZERO_POINT_HELPER(TransposeConv2DOp, Input)
1881 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
1882 ZERO_POINT_HELPER(AvgPool2dOp, Input)
1883 ZERO_POINT_HELPER(AvgPool2dOp, Output)
1884 ZERO_POINT_HELPER(MatMulOp, A)
1885 ZERO_POINT_HELPER(MatMulOp, B)
1886 ZERO_POINT_HELPER(NegateOp, Input1)
1887 ZERO_POINT_HELPER(NegateOp, Output)
1888 ZERO_POINT_HELPER(RescaleOp, Input)
1889 ZERO_POINT_HELPER(RescaleOp, Output)
1890 #undef ZERO_POINT_HELPER
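
// For reference, the macro above expands to simple accessors; e.g. for
// (Conv2DOp, Input) the expansion is (sketch):
//
//   FailureOr<int64_t> tosa::Conv2DOp::getInputZeroPoint() {
//     return getZeroPoint(*this, getInputZp());
//   }
//   LogicalResult tosa::Conv2DOp::verifyInputZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getInputZp(), zp, "Input");
//   }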
1891 
1892 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1893  MLIRContext *context, ::std::optional<Location> location,
1894  TransposeOp::Adaptor adaptor,
1895  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1896  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1897 
1898  // If the input rank is unknown, the output rank is
1899  // unknown as well.
1900  if (!inputShape.hasRank()) {
1901  inferredReturnShapes.push_back(ShapedTypeComponents());
1902  return success();
1903  }
1904 
1905  const auto inputRank = inputShape.getRank();
1906 
1907  // This would imply the number of permutation values does not match the rank
1908  // of the input, which is illegal.
1909  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
1910  return failure();
1911  }
1912 
1913  SmallVector<int64_t> outputShape;
1914  // Rank-0 means no permutations matter.
1915  if (inputRank == 0) {
1916  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1917  return success();
1918  }
1919 
1920  // Check whether the input dimensions are all the same.
1921  bool allTheSame = true;
1922  for (int i = 1, s = inputRank; i < s; i++) {
1923  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1924  allTheSame = false;
1925  break;
1926  }
1927  }
1928 
1929  // If all of the input dimensions are the same, we don't care about the
1930  // permutation.
1931  if (allTheSame) {
1932  outputShape.resize(inputRank, inputShape.getDimSize(0));
1933  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1934  return success();
1935  }
1936 
1937  outputShape.resize(inputRank, ShapedType::kDynamic);
1938 
1939  // Constant permutation values must be within the input rank.
1940  if (llvm::any_of(adaptor.getPerms(),
1941  [inputRank](const auto i) { return i >= inputRank; }))
1942  return failure();
1943 
1944  outputShape.reserve(inputRank);
1945  for (int i = 0, s = inputRank; i < s; i++) {
1946  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
1947  }
1948 
1949  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1950  return success();
1951 }
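
// Worked example of the inference above (illustrative values): for an input
// of shape [1, 2, 3] and perms = [2, 0, 1], output dim i takes the size of
// input dim perms[i], giving an output shape of [3, 1, 2].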
1952 
1953 LogicalResult tosa::TransposeOp::verify() {
1954  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1955  /* outType = */ getOutput().getType())
1956  .failed()) {
1957  return failure();
1958  }
1959  TensorType inputType = getInput1().getType();
1960  TensorType outputType = getOutput().getType();
1961  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
1962 
1963  if (inputType.hasRank() &&
1964  constantPerms.size() != static_cast<size_t>(inputType.getRank()))
1965  return emitOpError() << "expected perms attribute to have size "
1966  << inputType.getRank() << " (input rank) but got size "
1967  << constantPerms.size();
1968  if (inputType.hasRank() && outputType.hasRank() &&
1969  inputType.getRank() != outputType.getRank())
1970  return emitOpError()
1971  << "expected input tensor rank to equal result tensor rank";
1972  if (outputType.hasRank() &&
1973  constantPerms.size() != static_cast<size_t>(outputType.getRank()))
1974  return emitOpError() << "expected perms attribute to have size "
1975  << outputType.getRank()
1976  << " (output rank) but got size "
1977  << constantPerms.size();
1978 
1979  if (!llvm::all_of(constantPerms,
1980  [&constantPerms](int32_t s) {
1981  return s >= 0 &&
1982  static_cast<size_t>(s) < constantPerms.size();
1983  }) ||
1984  !isPermutationVector(llvm::to_vector(llvm::map_range(
1985  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1986  return emitOpError() << "expected valid permutation indices";
1987 
1988  // Verify that the types of the input and output tensors are properly
1989  // permuted.
1990  if (inputType.hasRank() && outputType.hasRank()) {
1991  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1992  inputType.getRank() == outputType.getRank());
1993 
1994  for (auto i = 0; i < outputType.getRank(); i++) {
1995  if (inputType.isDynamicDim(constantPerms[i]) ||
1996  outputType.isDynamicDim(i))
1997  continue;
1998 
1999  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
2000  return emitOpError()
2001  << "expected output tensor dim " << i << " to match "
2002  << "input dim " << constantPerms[i] << " with value of "
2003  << inputType.getDimSize(constantPerms[i]);
2004  }
2005  }
2006 
2007  return success();
2008 }
2009 
2010 LogicalResult TransposeOp::reifyResultShapes(
2011  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2012 
2013  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2014 
2015  Value input = getInput1();
2016  auto inputType = cast<TensorType>(input.getType());
2017 
2018  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2019  for (auto dim : transposePerms) {
2020  int32_t dimInInput = transposePerms[dim];
2021  if (inputType.isDynamicDim(dimInInput))
2022  returnedDims[dim] =
2023  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2024  .getResult();
2025  else
2026  returnedDims[dim] =
2027  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2028  }
2029 
2030  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2031  return success();
2032 }
2033 
2034 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2035  MLIRContext *context, ::std::optional<Location> location,
2036  GatherOp::Adaptor adaptor,
2037  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2038  llvm::SmallVector<int64_t> outputShape;
2039  outputShape.resize(3, ShapedType::kDynamic);
2040 
2041  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2042  if (valuesShape.hasRank()) {
2043  outputShape[0] = valuesShape.getDimSize(0);
2044  outputShape[2] = valuesShape.getDimSize(2);
2045  }
2046 
2047  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2048  if (indicesShape.hasRank()) {
2049  if (outputShape[0] == ShapedType::kDynamic)
2050  outputShape[0] = indicesShape.getDimSize(0);
2051  if (outputShape[1] == ShapedType::kDynamic)
2052  outputShape[1] = indicesShape.getDimSize(1);
2053  }
2054 
2055  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2056  return success();
2057 }
2058 
2059 LogicalResult tosa::GatherOp::verify() {
2060  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2061  /* outType = */ getOutput().getType());
2062 }
2063 
2064 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2065  MLIRContext *context, ::std::optional<Location> location,
2066  ResizeOp::Adaptor adaptor,
2067  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2068  llvm::SmallVector<int64_t, 4> outputShape;
2069  outputShape.resize(4, ShapedType::kDynamic);
2070 
2071  ShapeAdaptor inputShape(adaptor.getInput().getType());
2072  if (!inputShape.hasRank())
2073  return failure();
2074 
2075  outputShape[0] = inputShape.getDimSize(0);
2076  outputShape[3] = inputShape.getDimSize(3);
2077  int64_t inputHeight = inputShape.getDimSize(1);
2078  int64_t inputWidth = inputShape.getDimSize(2);
2079 
2080  if ((inputHeight == ShapedType::kDynamic) ||
2081  (inputWidth == ShapedType::kDynamic))
2082  return failure();
2083 
2084  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2085  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2086  scaleInt) ||
2087  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2088  offsetInt) ||
2089  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2090  borderInt)) {
2091  return failure();
2092  }
2093 
2094  // Compute the output shape based on attributes: scale, offset, and border.
2095  outputShape[1] =
2096  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2097  scaleInt[1]) +
2098  1;
2099 
2100  outputShape[2] =
2101  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2102  scaleInt[3]) +
2103  1;
2104 
2105  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2106  return success();
2107 }
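
// Worked example of the output-size computation above (illustrative values):
// with input_height = 8, scale = [2, 1, 2, 1], offset = [0, 0] and
// border = [0, 0], output_height = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15.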
2108 
2109 LogicalResult tosa::ResizeOp::verify() {
2110  const Value input = getInput();
2111  const Value output = getOutput();
2112  const RankedTensorType inputType =
2113  llvm::dyn_cast<RankedTensorType>(input.getType());
2114  const RankedTensorType outputType =
2115  llvm::dyn_cast<RankedTensorType>(output.getType());
2116 
2117  if (!inputType)
2118  return emitOpError("expect a ranked input tensor");
2119  if (!outputType)
2120  return emitOpError("expect a ranked output tensor");
2121 
2122  const int64_t oh = outputType.getDimSize(1);
2123  const int64_t ow = outputType.getDimSize(2);
2124  const int64_t ih = inputType.getDimSize(1);
2125  const int64_t iw = inputType.getDimSize(2);
2126 
2127  SmallVector<int64_t> scaleValues;
2128  SmallVector<int64_t> offsetValues;
2129  SmallVector<int64_t> borderValues;
2130  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2131  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2132  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2133  // Skip following checks if shape is not constant
2134  return success();
2135  }
2136 
2137  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2138  return emitOpError("expect all scale values to be > 0, got ")
2139  << scaleValues;
2140 
2141  const int64_t scaleYN = scaleValues[0];
2142  const int64_t scaleYD = scaleValues[1];
2143  const int64_t scaleXN = scaleValues[2];
2144  const int64_t scaleXD = scaleValues[3];
2145 
2146  const int64_t offsetY = offsetValues[0];
2147  const int64_t offsetX = offsetValues[1];
2148 
2149  const int64_t borderY = borderValues[0];
2150  const int64_t borderX = borderValues[1];
2151 
2152  // Skip the check when the input height may be broadcast (ih == 1),
2153  // since Linalg, a consumer of TOSA, expects broadcasting support
2154  // in resize to be available. Taking the cautious approach for now;
2155  // we can consider removing support for broadcasting later.
2156  if (ih != ShapedType::kDynamic && ih != 1) {
2157  const std::optional<int64_t> calculatedOutHeightMinusOne =
2158  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2159  if (!calculatedOutHeightMinusOne.has_value())
2160  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2161  "border_y ")
2162  << "to be wholly divisible by scale_y_d, got ((" << ih
2163  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2164  << ") / " << scaleYD;
2165  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2166  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2167  return emitOpError("calculated output height did not match expected: ")
2168  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2169  }
2170 
2171  // Skip the check when the input width may be broadcast (iw == 1),
2172  // since Linalg, a consumer of TOSA, expects broadcasting support
2173  // in resize to be available. Taking the cautious approach for now;
2174  // we can consider removing support for broadcasting later.
2175  if (iw != ShapedType::kDynamic && iw != 1) {
2176  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2177  const std::optional<int64_t> calculatedOutWidthMinusOne =
2178  idivCheck(scaledInWidth, scaleXD);
2179  if (!calculatedOutWidthMinusOne.has_value())
2180  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2181  "border_x ")
2182  << "to be wholly divisible by scale_x_d, got ((" << iw
2183  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2184  << ") / " << scaleXD;
2185  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2186  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2187  return emitOpError("calculated output width did not match expected: ")
2188  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2189  }
2190 
2191  return success();
2192 }
2193 
2194 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2195  MLIRContext *context, ::std::optional<Location> location,
2196  ScatterOp::Adaptor adaptor,
2197  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2198  llvm::SmallVector<int64_t> outputShape;
2199  outputShape.resize(3, ShapedType::kDynamic);
2200 
2201  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2202  if (valuesInShape.hasRank()) {
2203  outputShape[0] = valuesInShape.getDimSize(0);
2204  outputShape[1] = valuesInShape.getDimSize(1);
2205  outputShape[2] = valuesInShape.getDimSize(2);
2206  }
2207 
2208  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2209  if (indicesShape.hasRank()) {
2210  if (outputShape[0] == ShapedType::kDynamic)
2211  outputShape[0] = indicesShape.getDimSize(0);
2212  }
2213 
2214  ShapeAdaptor inputShape(adaptor.getInput().getType());
2215  if (inputShape.hasRank()) {
2216  if (outputShape[0] == ShapedType::kDynamic)
2217  outputShape[0] = inputShape.getDimSize(0);
2218  if (outputShape[2] == ShapedType::kDynamic)
2219  outputShape[2] = inputShape.getDimSize(2);
2220  }
2221 
2222  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2223  return success();
2224 }
2225 
2226 LogicalResult tosa::ScatterOp::verify() {
2227  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2228  /* outType = */ getValuesOut().getType())
2229  .failed() ||
2230  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2231  /* outType = */ getValuesOut().getType())
2232  .failed()) {
2233  return failure();
2234  }
2235  return success();
2236 }
2237 
2238 static LogicalResult ReduceInferReturnTypes(
2239  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2240  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2241  int64_t axisVal = axis.getValue().getSExtValue();
2242  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2243  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2244  return success();
2245  }
2246 
2247  SmallVector<int64_t> outputShape;
2248  operandShape.getDims(outputShape);
2249  outputShape[axisVal] = 1;
2250  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2251  return success();
2252 }
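
// Worked example of the shape rule above (illustrative values): reducing an
// input of shape [2, 3, 4] along axis = 1 yields an output shape of
// [2, 1, 4]; only the reduced dimension collapses to 1.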
2253 
2254 #define COMPATIBLE_RETURN_TYPES(OP) \
2255  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2256  if (l.size() != r.size() || l.size() != 1) \
2257  return false; \
2258  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2259  return false; \
2260  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2261  }
2262 
2263 #define REDUCE_SHAPE_INFER(OP) \
2264  LogicalResult OP::inferReturnTypeComponents( \
2265  MLIRContext *context, ::std::optional<Location> location, \
2266  OP::Adaptor adaptor, \
2267  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2268  Type inputType = \
2269  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2270  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2271  const Properties &prop = adaptor.getProperties(); \
2272  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2273  inferredReturnShapes); \
2274  } \
2275  COMPATIBLE_RETURN_TYPES(OP)
2276 
2277 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2278 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2279 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2280 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2281 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2282 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2283 #undef REDUCE_SHAPE_INFER
2284 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2285 #undef COMPATIBLE_RETURN_TYPES
2286 
2287 template <typename T>
2288 static LogicalResult verifyReduceOp(T op) {
2289  // All TOSA reduce Ops have input, output and axis.
2290  TensorType inputType = op.getInput().getType();
2291  TensorType outputType = op.getOutput().getType();
2292  int32_t reduceAxis = op.getAxis();
2293 
2294  if (reduceAxis < 0) {
2295  op.emitOpError("reduce axis must not be negative");
2296  return failure();
2297  }
2298  if (inputType.hasRank()) {
2299  int64_t inputRank = inputType.getRank();
2300  // We allow for a special case where the input/output shape has rank 0 and
2301  // axis is also 0.
2302  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2303  op.emitOpError("expect input tensor rank (")
2304  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2305  << ")";
2306  return failure();
2307  }
2308  }
2309  if (outputType.hasRank()) {
2310  int64_t outputRank = outputType.getRank();
2311  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2312  op.emitOpError(
2313  "expect output tensor rank to be equal to input tensor rank");
2314  return failure();
2315  }
2316  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2317  op.emitOpError("expect output tensor rank (")
2318  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2319  << ")";
2320  return failure();
2321  }
2322  // We can only verify the reduced dimension size to be 1 if this is not
2323  // the special case of output rank == 0.
2324  if (outputRank != 0) {
2325  auto outputShape = outputType.getShape();
2326  if (!outputType.isDynamicDim(reduceAxis) &&
2327  outputShape[reduceAxis] != 1) {
2328  op.emitOpError("expect reduced dimension size to be 1, got ")
2329  << outputShape[reduceAxis];
2330  return failure();
2331  }
2332  }
2333  }
2334  return success();
2335 }
2336 
2337 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2338 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2339 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2340 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2341 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2342 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2343 
2344 static LogicalResult NAryInferReturnTypes(
2345  const ValueShapeRange &operands,
2346  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2347  llvm::SmallVector<int64_t> outShape;
2348  if (resolveBroadcastShape(operands, outShape).failed()) {
2349  inferredReturnShapes.push_back(ShapedTypeComponents());
2350  } else {
2351  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2352  }
2353  return success();
2354 }
2355 
2356 #define NARY_SHAPE_INFER(OP) \
2357  LogicalResult OP::inferReturnTypeComponents( \
2358  MLIRContext *context, ::std::optional<Location> location, \
2359  ValueShapeRange operands, DictionaryAttr attributes, \
2360  OpaqueProperties properties, RegionRange regions, \
2361  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2362  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2363  }
2364 
2365 NARY_SHAPE_INFER(tosa::AbsOp)
2366 NARY_SHAPE_INFER(tosa::AddOp)
2367 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2368 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2369 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2370 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2371 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2372 NARY_SHAPE_INFER(tosa::CastOp)
2373 NARY_SHAPE_INFER(tosa::CeilOp)
2374 NARY_SHAPE_INFER(tosa::ClampOp)
2375 NARY_SHAPE_INFER(tosa::ClzOp)
2376 NARY_SHAPE_INFER(tosa::CosOp)
2377 NARY_SHAPE_INFER(tosa::ExpOp)
2378 NARY_SHAPE_INFER(tosa::FloorOp)
2379 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2380 NARY_SHAPE_INFER(tosa::GreaterOp)
2381 NARY_SHAPE_INFER(tosa::IdentityOp)
2382 NARY_SHAPE_INFER(tosa::IntDivOp)
2383 NARY_SHAPE_INFER(tosa::LogOp)
2384 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2385 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2386 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2387 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2388 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2389 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2390 NARY_SHAPE_INFER(tosa::MaximumOp)
2391 NARY_SHAPE_INFER(tosa::MinimumOp)
2392 NARY_SHAPE_INFER(tosa::PowOp)
2393 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2394 NARY_SHAPE_INFER(tosa::ReverseOp)
2395 NARY_SHAPE_INFER(tosa::RsqrtOp)
2396 NARY_SHAPE_INFER(tosa::SinOp)
2397 NARY_SHAPE_INFER(tosa::SelectOp)
2398 NARY_SHAPE_INFER(tosa::SubOp)
2399 NARY_SHAPE_INFER(tosa::TanhOp)
2400 NARY_SHAPE_INFER(tosa::ErfOp)
2401 NARY_SHAPE_INFER(tosa::SigmoidOp)
2402 #undef NARY_SHAPE_INFER
2403 
2404 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2405  MLIRContext *context, ::std::optional<Location> location,
2406  NegateOp::Adaptor adaptor,
2407  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2408  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2409  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2410  return success();
2411 }
2412 
2413 LogicalResult tosa::NegateOp::verify() {
2414  // Verify same element type
2415  const Type input1Type = getInput1().getType();
2416  const Type outputType = getOutput().getType();
2417  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2418  return failure();
2419 
2420  // Verify same shape
2421  const SmallVector<Type, 2> types = {input1Type, outputType};
2422  if (failed(verifyCompatibleShapes(types)))
2423  return emitOpError() << "requires the same shape for input1 and output";
2424 
2425  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2426  const Type input1ZpEType =
2427  getStorageElementTypeOrSelf(getInput1Zp().getType());
2428  if (input1EType != input1ZpEType) {
2429  return emitOpError("expect both input1 and its zero point are the same "
2430  "element type, got ")
2431  << input1EType << " and " << input1ZpEType;
2432  }
2433  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2434  const Type outputZpEType =
2435  getStorageElementTypeOrSelf(getOutputZp().getType());
2436  if (outputEType != outputZpEType) {
2437  return emitOpError("expect both output and its zero point are the same "
2438  "element type, got ")
2439  << outputEType << " and " << outputZpEType;
2440  }
2441 
2442  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2443  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2444  return failure();
2445 
2446  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2447  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2448  return failure();
2449 
2450  return success();
2451 }
2452 
2453 static LogicalResult poolingInferReturnTypes(
2454  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2455  ArrayRef<int64_t> pad,
2456  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2457  llvm::SmallVector<int64_t> outputShape;
2458  outputShape.resize(4, ShapedType::kDynamic);
2459 
2460  // If the input shape is unavailable, we only know the output rank.
2461  if (!inputShape) {
2462  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2463  return success();
2464  }
2465 
2466  // Batch size and number of channels match the input for a pooling layer.
2467  outputShape[0] = inputShape.getDimSize(0);
2468  outputShape[3] = inputShape.getDimSize(3);
2469 
2470  int64_t height = inputShape.getDimSize(1);
2471  int64_t width = inputShape.getDimSize(2);
2472 
2473  if (!ShapedType::isDynamic(height)) {
2474  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2475  outputShape[1] = padded / stride[0] + 1;
2476  }
2477 
2478  if (!ShapedType::isDynamic(width)) {
2479  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2480  outputShape[2] = padded / stride[1] + 1;
2481  }
2482 
2483  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2484  return success();
2485 }
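
// Worked example of the pooling output size above (illustrative values): for
// height = 32, pad = [0, 0, 0, 0], kernel = [2, 2] and stride = [2, 2], the
// output height is (32 + 0 + 0 - 2) / 2 + 1 = 16.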
2486 
2487 LogicalResult Conv2DOp::inferReturnTypeComponents(
2488  MLIRContext *context, ::std::optional<Location> location,
2489  Conv2DOp::Adaptor adaptor,
2490  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2491  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2492 
2493  int64_t inputWidth = ShapedType::kDynamic;
2494  int64_t inputHeight = ShapedType::kDynamic;
2495  int64_t weightWidth = ShapedType::kDynamic;
2496  int64_t weightHeight = ShapedType::kDynamic;
2497 
2498  // Input shape describes input width/height and batch.
2499 
2500  ShapeAdaptor inputShape(adaptor.getInput().getType());
2501  if (inputShape.hasRank()) {
2502  outputShape[0] = inputShape.getDimSize(0);
2503  inputHeight = inputShape.getDimSize(1);
2504  inputWidth = inputShape.getDimSize(2);
2505  }
2506 
2507  // Weight shape describes the filter width/height and the output channels.
2508  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2509  if (weightShape.hasRank()) {
2510  outputShape[3] = weightShape.getDimSize(0);
2511  weightHeight = weightShape.getDimSize(1);
2512  weightWidth = weightShape.getDimSize(2);
2513  }
2514 
2515  // Bias shape can describe the output channels.
2516  ShapeAdaptor biasShape(adaptor.getBias().getType());
2517  if (biasShape.hasRank()) {
2518  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2519  ? biasShape.getDimSize(0)
2520  : outputShape[3];
2521  }
2522 
2523  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2524  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2525  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2526 
2527  if (!ShapedType::isDynamic(inputHeight) &&
2528  !ShapedType::isDynamic(weightHeight)) {
2529  int64_t inputSize = inputHeight + padding[0] + padding[1];
2530  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2531  int64_t unstridedResult = inputSize - filterSize + 1;
2532  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2533  }
2534 
2535  if (!ShapedType::isDynamic(inputWidth) &&
2536  !ShapedType::isDynamic(weightWidth)) {
2537  int64_t inputSize = inputWidth + padding[2] + padding[3];
2538  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2539  int64_t unstridedResult = inputSize - filterSize + 1;
2540  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2541  }
2542 
2543  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2544  return success();
2545 }
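
// Worked example of the height computation above (illustrative values): with
// input_height = 16, pad = [1, 1], kernel_height = 3, dilation = 1 and
// stride = 2, inputSize = 18, filterSize = 3, unstridedResult = 16 and the
// output height is (16 - 1) / 2 + 1 = 8.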
2546 
2547 LogicalResult Conv2DOp::verify() {
2548  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2549  return failure();
2550 
2551  llvm::ArrayRef<int64_t> padding = getPad();
2552  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
2553  return emitOpError("expect all padding values to be >= 0, got ") << padding;
2554 
2555  llvm::ArrayRef<int64_t> strides = getStride();
2556  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
2557  return emitOpError("expect all stride values to be >= 1, got ") << strides;
2558 
2559  llvm::ArrayRef<int64_t> dilations = getDilation();
2560  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
2561  return emitOpError("expect all dilation values to be >= 1, got ")
2562  << dilations;
2563 
2564  const RankedTensorType outputType =
2565  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2566  if (!outputType)
2567  // Skip following checks if output is not ranked
2568  return success();
2569 
2570  const RankedTensorType inputType =
2571  llvm::dyn_cast<RankedTensorType>(getInput().getType());
2572  const RankedTensorType weightType =
2573  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
2574 
2575  if (inputType && weightType) {
2576  const auto verifyOutputSize =
2577  [this](const int64_t inputSize, const int64_t kernelSize,
2578  const int64_t outputSize, const int64_t padBefore,
2579  const int64_t padAfter, const int64_t stride,
2580  const int64_t dilation, const llvm::StringRef dimName,
2581  const llvm::StringRef dimAxis,
2582  const llvm::StringRef padBeforeName,
2583  const llvm::StringRef padAfterName) -> LogicalResult {
2584  if (inputSize == ShapedType::kDynamic ||
2585  kernelSize == ShapedType::kDynamic)
2586  return success();
2587 
2588  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
2589  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
2590  stride);
2591  if (!calculatedOutSizeMinusOne.has_value())
2592  return emitOpError("expected input_")
2593  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
2594  << padAfterName << " - (kernel_" << dimName
2595  << " - 1) * dilation_" << dimAxis
2596  << " to be wholly divisible by stride_" << dimAxis << ", got ("
2597  << inputSize << " - 1 + " << padBefore << " + " << padAfter
2598  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
2599  << stride;
2600 
2601  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
2602  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
2603  return emitOpError("calculated output ")
2604  << dimName << " did not match expected: "
2605  << "calculated=" << calculatedOutSize
2606  << ", expected=" << outputSize;
2607 
2608  return success();
2609  };
2610 
2611  if (failed(verifyOutputSize(
2612  inputType.getDimSize(1), weightType.getDimSize(1),
2613  outputType.getDimSize(1), padding[0], padding[1], strides[0],
2614  dilations[0], "height", "y", "top", "bottom")))
2615  return failure();
2616 
2617  if (failed(verifyOutputSize(
2618  inputType.getDimSize(2), weightType.getDimSize(2),
2619  outputType.getDimSize(2), padding[2], padding[3], strides[1],
2620  dilations[1], "width", "x", "left", "right")))
2621  return failure();
2622  }
2623 
2624  const RankedTensorType biasType =
2625  llvm::dyn_cast<RankedTensorType>(getBias().getType());
2626  if (!biasType)
2627  // Skip following checks if bias is not ranked
2628  return success();
2629 
2630  const int64_t biasChannels = biasType.getDimSize(0);
2631  const int64_t outputChannels = outputType.getDimSize(3);
2632  if (biasChannels == ShapedType::kDynamic ||
2633  outputChannels == ShapedType::kDynamic)
2634  // Skip following checks if biasChannels or outputChannels is dynamic dim
2635  return success();
2636 
2637  if (biasChannels != outputChannels && biasChannels != 1)
2638  return emitOpError(
2639  "bias channels expected to be equal to output channels (")
2640  << outputChannels << ") or 1, got " << biasChannels;
2641  return success();
2642 }
2643 
2644 LogicalResult Conv3DOp::inferReturnTypeComponents(
2645  MLIRContext *context, ::std::optional<Location> location,
2646  Conv3DOp::Adaptor adaptor,
2647  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2648  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
2649 
2650  int64_t inputWidth = ShapedType::kDynamic;
2651  int64_t inputHeight = ShapedType::kDynamic;
2652  int64_t inputDepth = ShapedType::kDynamic;
2653 
2654  int64_t weightWidth = ShapedType::kDynamic;
2655  int64_t weightHeight = ShapedType::kDynamic;
2656  int64_t weightDepth = ShapedType::kDynamic;
2657 
2658  // Input shape describes input width/height and batch.
2659  ShapeAdaptor inputShape(adaptor.getInput().getType());
2660  if (inputShape.hasRank()) {
2661  outputShape[0] = inputShape.getDimSize(0);
2662  inputDepth = inputShape.getDimSize(1);
2663  inputHeight = inputShape.getDimSize(2);
2664  inputWidth = inputShape.getDimSize(3);
2665  }
2666 
2667  // Weight shape describes the filter width/height and the output channels.
2668  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2669  if (weightShape.hasRank()) {
2670  outputShape[4] = weightShape.getDimSize(0);
2671  weightDepth = weightShape.getDimSize(1);
2672  weightHeight = weightShape.getDimSize(2);
2673  weightWidth = weightShape.getDimSize(3);
2674  }
2675 
2676  // Bias shape can describe the output channels.
2677  ShapeAdaptor biasShape(adaptor.getBias().getType());
2678  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
2679  outputShape[4] = biasShape.getDimSize(0);
2680  }
2681 
2682  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2683  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2684  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
2685 
2686  if (!ShapedType::isDynamic(inputDepth) &&
2687  !ShapedType::isDynamic(weightDepth)) {
2688  int32_t inputSize = inputDepth + pad[0] + pad[1];
2689  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
2690  int32_t unstridedResult = inputSize - filterSize + 1;
2691  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2692  }
2693 
2694  if (!ShapedType::isDynamic(inputHeight) &&
2695  !ShapedType::isDynamic(weightHeight)) {
2696  int32_t inputSize = inputHeight + pad[2] + pad[3];
2697  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
2698  int32_t unstridedResult = inputSize - filterSize + 1;
2699  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2700  }
2701 
2702  if (!ShapedType::isDynamic(inputWidth) &&
2703  !ShapedType::isDynamic(weightWidth)) {
2704  int32_t inputSize = inputWidth + pad[4] + pad[5];
2705  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
2706  int32_t unstridedResult = inputSize - filterSize + 1;
2707  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
2708  }
2709 
2710  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2711  return success();
2712 }
2713 
2714 LogicalResult Conv3DOp::verify() {
2715  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2716  return failure();
2717  return success();
2718 }
2719 
2720 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
2721  MLIRContext *context, ::std::optional<Location> location,
2722  AvgPool2dOp::Adaptor adaptor,
2723  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2724  ShapeAdaptor inputShape(adaptor.getInput().getType());
2725  const Properties &prop = adaptor.getProperties();
2726  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2727  inferredReturnShapes);
2728 }
2729 
2730 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2731  MLIRContext *context, ::std::optional<Location> location,
2732  MaxPool2dOp::Adaptor adaptor,
2733  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2734  ShapeAdaptor inputShape(adaptor.getInput().getType());
2735  const Properties &prop = adaptor.getProperties();
2736  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2737  inferredReturnShapes);
2738 }
2739 
2740 LogicalResult MaxPool2dOp::verify() {
2741  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
2742  /* outType = */ getOutput().getType())))
2743  return failure();
2744 
2745  if (failed(verifyPoolingOp(*this)))
2746  return failure();
2747 
2748  return success();
2749 }
2750 
2751 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2752  MLIRContext *context, ::std::optional<Location> location,
2753  DepthwiseConv2DOp::Adaptor adaptor,
2754  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2755  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2756 
2757  int64_t inputWidth = ShapedType::kDynamic;
2758  int64_t inputHeight = ShapedType::kDynamic;
2759  int64_t inputChannels = ShapedType::kDynamic;
2760 
2761  int64_t weightWidth = ShapedType::kDynamic;
2762  int64_t weightHeight = ShapedType::kDynamic;
2763  int64_t depthChannels = ShapedType::kDynamic;
2764 
2765  // Input shape describes input width/height and batch.
2766  ShapeAdaptor inputShape(adaptor.getInput().getType());
2767  if (inputShape.hasRank()) {
2768  outputShape[0] = inputShape.getDimSize(0);
2769  inputHeight = inputShape.getDimSize(1);
2770  inputWidth = inputShape.getDimSize(2);
2771  inputChannels = inputShape.getDimSize(3);
2772  }
2773 
2774  // Weight shape describes the filter width/height and the output channels.
2775  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2776  if (weightShape.hasRank()) {
2777  weightHeight = weightShape.getDimSize(0);
2778  weightWidth = weightShape.getDimSize(1);
2779  inputChannels = ShapedType::isDynamic(inputChannels)
2780  ? weightShape.getDimSize(2)
2781  : inputChannels;
2782  depthChannels = weightShape.getDimSize(3);
2783  }
2784 
2785  // If both inputChannels and depthChannels are available, we can determine
2786  // the output channels.
2787  if (!ShapedType::isDynamic(inputChannels) &&
2788  !ShapedType::isDynamic(depthChannels)) {
2789  outputShape[3] = inputChannels * depthChannels;
2790  }
2791 
2792  // Bias shape can describe the output channels.
2793  ShapeAdaptor biasShape(adaptor.getBias().getType());
2794  if (biasShape.hasRank()) {
2795  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2796  ? biasShape.getDimSize(0)
2797  : outputShape[3];
2798  }
2799 
2800  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2801  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2802  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2803 
2804  if (!ShapedType::isDynamic(inputHeight) &&
2805  !ShapedType::isDynamic(weightHeight)) {
2806  int64_t inputSize = inputHeight + padding[0] + padding[1];
2807  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2808  int64_t unstridedResult = inputSize - filterSize + 1;
2809  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2810  }
2811 
2812  if (!ShapedType::isDynamic(inputWidth) &&
2813  !ShapedType::isDynamic(weightWidth)) {
2814  int64_t inputSize = inputWidth + padding[2] + padding[3];
2815  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2816  int64_t unstridedResult = inputSize - filterSize + 1;
2817  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2818  }
2819 
2820  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2821  return success();
2822 }
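
// Worked example of the channel inference above (illustrative values): an
// input with 8 channels and a weight with depth multiplier 2 give
// 8 * 2 = 16 output channels.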
2823 
2824 LogicalResult DepthwiseConv2DOp::verify() {
2825  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2826  return failure();
2827  return success();
2828 }
2829 
2830 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
2831  MLIRContext *context, ::std::optional<Location> location,
2832  TransposeConv2DOp::Adaptor adaptor,
2833  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2834  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2835 
2836  int64_t inputWidth = ShapedType::kDynamic;
2837  int64_t inputHeight = ShapedType::kDynamic;
2838  int64_t weightWidth = ShapedType::kDynamic;
2839  int64_t weightHeight = ShapedType::kDynamic;
2840 
2841  // Input shape describes input width/height and batch.
2842  ShapeAdaptor inputShape(adaptor.getInput().getType());
2843  if (inputShape.hasRank()) {
2844  outputShape[0] = ShapedType::isDynamic(outputShape[0])
2845  ? inputShape.getDimSize(0)
2846  : outputShape[0];
2847  inputHeight = inputShape.getDimSize(1);
2848  inputWidth = inputShape.getDimSize(2);
2849  }
2850 
2851  // Weight shape describes the filter width/height and the output channels.
2852  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2853  if (weightShape.hasRank()) {
2854  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2855  ? weightShape.getDimSize(0)
2856  : outputShape[3];
2857  weightHeight = weightShape.getDimSize(1);
2858  weightWidth = weightShape.getDimSize(2);
2859  }
2860 
2861  // Bias shape can describe the output channels.
2862  ShapeAdaptor biasShape(adaptor.getBias().getType());
2863  if (biasShape.hasRank()) {
2864  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2865  ? biasShape.getDimSize(0)
2866  : outputShape[3];
2867  }
2868 
2869  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
2870  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2871 
2872  if (!ShapedType::isDynamic(inputHeight) &&
2873  !ShapedType::isDynamic(weightHeight)) {
2874  int64_t calculateSize =
2875  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2876  outputShape[1] =
2877  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
2878  }
2879 
2880  if (!ShapedType::isDynamic(inputWidth) &&
2881  !ShapedType::isDynamic(weightWidth)) {
2882  int64_t calculateSize =
2883  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2884  outputShape[2] =
2885  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
2886  }
2887 
2888  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2889  return success();
2890 }
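
// Worked example of the output size above (illustrative values): with
// input_height = 4, stride = 2, out_pad = [0, 0] and kernel_height = 3, the
// inferred output height is (4 - 1) * 2 + 0 + 0 + 3 = 9.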
2891 
2892 LogicalResult TransposeConv2DOp::verify() {
2893  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2894  return failure();
2895  return success();
2896 }
2897 
2898 LogicalResult RescaleOp::verify() {
2899  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
2900  if (!inputType) {
2901  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
2902  return failure();
2903  }
2904 
2905  auto inputElementType =
2906  getStorageElementTypeOrSelf(inputType.getElementType());
2907  if (!mlir::isa<IntegerType>(inputElementType)) {
2908  emitOpError("expect input to have integer element type, got ")
2909  << inputElementType;
2910  return failure();
2911  }
2912 
2913  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
2914  if (!outputType) {
2915  emitOpError("expect shaped tensor for output, got ")
2916  << getOutput().getType();
2917  return failure();
2918  }
2919 
2920  auto outputElementType =
2921  getStorageElementTypeOrSelf(outputType.getElementType());
2922  if (!mlir::isa<IntegerType>(outputElementType)) {
2923  emitOpError("expect output to have integer element type, got ")
2924  << outputElementType;
2925  return failure();
2926  }
2927 
2928  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
2929  .failed())
2930  return failure();
2931 
2932  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
2933  .failed())
2934  return failure();
2935 
2936  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
2937  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
2938  return failure();
2939 
2940  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2941  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2942  return failure();
2943 
2944  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
2945  if (!multiplierType) {
2946  emitOpError("expect shaped tensor for multiplier, got ")
2947  << getMultiplier().getType();
2948  return failure();
2949  }
2950 
2951  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
2952  if (!shiftType) {
2953  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
2954  return failure();
2955  }
2956 
2957  // multiplier element type must be i32 for scale32 = true
2958  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
2959  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
2960  << multiplierType.getElementType();
2961  return failure();
2962  }
2963 
2964  // multiplier element type must be i16 for scale32 = false
2965  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
2966  emitOpError(
2967  "expect i16 element type for multiplier for scale32=false, got ")
2968  << multiplierType.getElementType();
2969  return failure();
2970  }
2971 
2972  if (!inputType.hasRank())
2973  return success();
2974 
2975  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
2976  // if per_channel = false, and otherwise the size of the last axis of the
2977  // input shape.
2978  int64_t numChannels = 1;
2979  if (getPerChannel()) {
2980  numChannels = inputType.getDimSize(inputType.getRank() - 1);
2981  }
2982 
2983  if (!multiplierType.hasRank())
2984  return success();
2985 
2986  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
2987  // multiplier input has rank 1 by dialect definition
2988  if (multiplierShape[0] != ShapedType::kDynamic &&
2989  multiplierShape[0] != numChannels) {
2990  emitOpError("expect shape of { ")
2991  << numChannels << " } for multiplier input, got { "
2992  << multiplierShape[0] << " }";
2993  return failure();
2994  }
2995 
2996  if (!shiftType.hasRank())
2997  return success();
2998 
2999  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3000  // shift input has rank 1 by dialect definition
3001  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3002  emitOpError("expect shape of { ")
3003  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3004  return failure();
3005  }
3006 
3007  return success();
3008 }
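
// Illustration of the multiplier/shift shape rule verified above, with a
// hypothetical input of type tensor<2x6x8xi8>:
//   per_channel = false -> multiplier and shift must both have shape { 1 }
//   per_channel = true  -> multiplier and shift must both have shape { 8 }
// With scale32 = true the multiplier elements must be i32; with
// scale32 = false they must be i16.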
3009 
3010 LogicalResult RescaleOp::inferReturnTypeComponents(
3011  MLIRContext *context, ::std::optional<Location> location,
3012  RescaleOp::Adaptor adaptor,
3013  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3014  ShapeAdaptor inputShape(adaptor.getInput().getType());
3015  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3016  return success();
3017 }
3018 
3019 LogicalResult IfOp::inferReturnTypeComponents(
3020  MLIRContext *context, ::std::optional<Location> location,
3021  IfOp::Adaptor adaptor,
3022  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3023  llvm::SmallVector<tosa::YieldOp> yieldOps;
3024  for (Region *region : adaptor.getRegions()) {
3025  for (auto &block : *region)
3026  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3027  yieldOps.push_back(returnOp);
3028  }
3029 
3030  if (yieldOps.empty())
3031  return failure();
3032 
3033  // Get the initial type information for the yield op.
3034  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3035  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3036  for (auto operand : yieldOps.front().getOperands()) {
3037  resultKnowledge.push_back(
3038  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3039  }
3040 
3041  for (auto yieldOp : yieldOps) {
3042  if (resultKnowledge.size() != yieldOp.getNumOperands())
3043  return failure();
3044 
3045  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3046  int32_t index = it.index();
3047  auto meet = ValueKnowledge::meet(
3048  resultKnowledge[index],
3049  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3050  if (!meet)
3051  continue;
3052  resultKnowledge[index] = meet;
3053  }
3054  }
3055 
3056  for (const ValueKnowledge &result : resultKnowledge) {
3057  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3058  }
3059 
3060  return success();
3061 }
3062 
3063 LogicalResult WhileOp::inferReturnTypeComponents(
3064  MLIRContext *context, ::std::optional<Location> location,
3065  WhileOp::Adaptor adaptor,
3066  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3067  llvm::SmallVector<tosa::YieldOp> yieldOps;
3068  for (auto &block : adaptor.getBodyGraph())
3069  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3070  yieldOps.push_back(returnOp);
3071 
3072  // TOSA's while must have a tosa.yield as its terminator. If one is not
3073  // found, this tosa.while is invalid.
3074  if (yieldOps.empty())
3075  return failure();
3076 
3077  // Get the initial type information from the operand types.
3078  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3079  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3080  for (auto operand : yieldOps.front().getOperands()) {
3081  resultKnowledge.push_back(
3082  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3083  }
3084 
3085  for (auto yieldOp : yieldOps) {
3086  if (resultKnowledge.size() != yieldOp.getNumOperands())
3087  return failure();
3088 
3089  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3090  int32_t index = it.index();
3091  if (auto meet = ValueKnowledge::meet(
3092  resultKnowledge[index],
3093  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3094  resultKnowledge[index] = meet;
3095  }
3096  }
3097  }
3098 
3099  for (const ValueKnowledge &result : resultKnowledge) {
3100  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3101  }
3102 
3103  return success();
3104 }
3105 
3106 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3107  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3108  return llvm::to_vector<4>(vt.getShape());
3109  return std::nullopt;
3110 }
3111 
3112 // The parse and print methods of IfOp are modeled on the SCF dialect.
3113 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3114  // Create the 'then' and 'else' regions.
3115  result.regions.reserve(2);
3116  Region *thenRegion = result.addRegion();
3117  Region *elseRegion = result.addRegion();
3118 
3119  auto &builder = parser.getBuilder();
3120  OpAsmParser::UnresolvedOperand cond;
3121  // Create an i1 tensor type for the boolean condition.
3122  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3123  if (parser.parseOperand(cond) ||
3124  parser.resolveOperand(cond, i1Type, result.operands))
3125  return failure();
3126  // Parse optional results type list.
3127  if (parser.parseOptionalArrowTypeList(result.types))
3128  return failure();
3129  // Parse the 'then' region.
3130  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3131  return failure();
3132 
3133  // If we find an 'else' keyword then parse the 'else' region.
3134  if (!parser.parseOptionalKeyword("else")) {
3135  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3136  return failure();
3137  }
3138 
3139  // Parse the optional attribute list.
3140  if (parser.parseOptionalAttrDict(result.attributes))
3141  return failure();
3142  return success();
3143 }
3144 
3145 void IfOp::print(OpAsmPrinter &p) {
3146  bool printBlockTerminators = false;
3147 
3148  p << " " << getCondition();
3149  if (!getResults().empty()) {
3150  p << " -> (" << getResultTypes() << ")";
3151  // Print yield explicitly if the op defines values.
3152  printBlockTerminators = true;
3153  }
3154  p << ' ';
3155  p.printRegion(getThenGraph(),
3156  /*printEntryBlockArgs=*/false,
3157  /*printBlockTerminators=*/printBlockTerminators);
3158 
3159  // Print the 'else' region if it exists and has a block.
3160  auto &elseRegion = getElseGraph();
3161  if (!elseRegion.empty()) {
3162  p << " else ";
3163  p.printRegion(elseRegion,
3164  /*printEntryBlockArgs=*/false,
3165  /*printBlockTerminators=*/printBlockTerminators);
3166  }
3167 
3168  p.printOptionalAttrDict((*this)->getAttrs());
3169 }
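
// Illustration of the custom assembly handled by the parser/printer above
// (a rough sketch; value names are hypothetical):
//   %r = tosa.cond_if %cond -> (tensor<f32>) {
//     tosa.yield %a : tensor<f32>
//   } else {
//     tosa.yield %b : tensor<f32>
//   }
// The condition is resolved to tensor<i1>, and block terminators are printed
// only when the op defines results.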
3170 
3171 LogicalResult ReverseOp::verify() {
3172  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3173  /* outType = */ getOutput().getType())
3174  .failed())
3175  return failure();
3176  TensorType inputType = getInput1().getType();
3177  TensorType outputType = getOutput().getType();
3178  int32_t reverseAxis = getAxis();
3179 
3180  if (reverseAxis < 0)
3181  return emitOpError("expected non-negative reverse axis");
3182  if (inputType.hasRank()) {
3183  int64_t inputRank = inputType.getRank();
3184  // We allow for a special case where the input/output shape has rank 0 and
3185  // axis is also 0.
3186  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3187  return emitOpError("expect input tensor rank (")
3188  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3189  << ")";
3190  }
3191  if (outputType.hasRank()) {
3192  int64_t outputRank = outputType.getRank();
3193  if (inputType.hasRank() && outputRank != inputType.getRank())
3194  return emitOpError(
3195  "expect output tensor rank to be equal to input tensor rank");
3196  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3197  return emitOpError("expect output tensor rank (")
3198  << outputRank << ") to be larger than reverse axis ("
3199  << reverseAxis << ")";
3200  }
3201  return success();
3202 }
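
// Illustration of the axis checks above, with hypothetical shapes: reversing
// a tensor<2x3x4xf32> accepts axis 0, 1, or 2, while axis 3 is rejected
// because it is not smaller than the input rank. A rank-0 input/output with
// axis 0 is the only exception.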
3203 
3204 LogicalResult tosa::SelectOp::verify() {
3205  // verify input2 and input3 have same element type as output
3206  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3207  /* outType = */ getOutput().getType())
3208  .failed() ||
3209  verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3210  /* outType = */ getOutput().getType())
3211  .failed()) {
3212  return failure();
3213  }
3214  // verify input1 has element type of bool
3215  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3216  if (!predicateType) {
3217  return emitOpError("expect shaped tensor for input1, got ")
3218  << getInput1().getType();
3219  }
3220  auto predicateElementType = predicateType.getElementType();
3221  if (!predicateElementType.isInteger(1)) {
3222  return emitOpError("expect element type of bool for input1, got ")
3223  << predicateElementType;
3224  }
3225 
3226  return success();
3227 }
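
// Illustration of the checks above, with hypothetical shapes: a predicate
// (input1) of type tensor<4xi1> selecting between two tensor<4xf32> values
// (input2/input3) to produce a tensor<4xf32> result is accepted, since the
// predicate has i1 elements and the selected operands match the output's
// element type.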
3228 
3229 // The parse and print methods of WhileOp are modeled on the SCF dialect.
3230 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3231  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3232  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3233  Region *cond = result.addRegion();
3234  Region *body = result.addRegion();
3235 
3236  OptionalParseResult listResult =
3237  parser.parseOptionalAssignmentList(regionArgs, operands);
3238  if (listResult.has_value() && failed(listResult.value()))
3239  return failure();
3240 
3241  FunctionType functionType;
3242  SMLoc typeLoc = parser.getCurrentLocation();
3243  if (failed(parser.parseColonType(functionType)))
3244  return failure();
3245 
3246  result.addTypes(functionType.getResults());
3247 
3248  if (functionType.getNumInputs() != operands.size()) {
3249  return parser.emitError(typeLoc)
3250  << "expected as many input types as operands "
3251  << "(expected " << operands.size() << " got "
3252  << functionType.getNumInputs() << ")";
3253  }
3254 
3255  // Resolve input operands.
3256  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3257  parser.getCurrentLocation(),
3258  result.operands)))
3259  return failure();
3260 
3261  // Propagate the types into the region arguments.
3262  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3263  regionArgs[i].type = functionType.getInput(i);
3264 
3265  return failure(parser.parseRegion(*cond, regionArgs) ||
3266  parser.parseKeyword("do") || parser.parseRegion(*body) ||
3267  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3268 }
3269 
3270 static void printInitializationList(OpAsmPrinter &parser,
3271  Block::BlockArgListType blocksArgs,
3272  ValueRange initializers,
3273  StringRef prefix = "") {
3274  assert(blocksArgs.size() == initializers.size() &&
3275  "expected same length of arguments and initializers");
3276  if (initializers.empty())
3277  return;
3278 
3279  parser << prefix << '(';
3280  llvm::interleaveComma(
3281  llvm::zip(blocksArgs, initializers), parser,
3282  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3283  parser << ")";
3284 }
3285 
3286 void WhileOp::print(OpAsmPrinter &parser) {
3287  printInitializationList(parser, getCondGraph().front().getArguments(),
3288  getInputList(), " ");
3289  parser << " : ";
3290  parser.printFunctionalType(getInputList().getTypes(),
3291  getResults().getTypes());
3292  parser << ' ';
3293  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3294  parser << " do ";
3295  parser.printRegion(getBodyGraph());
3296  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3297 }
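
// Illustration of the custom assembly handled by the parser/printer above
// (a rough sketch; value names are hypothetical):
//   %res = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...  // cond region, terminated by a tosa.yield of a boolean tensor
//   } do {
//     ...  // body region, terminated by a tosa.yield of the iteration values
//   }
// printInitializationList emits the "(%arg0 = %init)" assignment list, and the
// functional type is printed between that list and the cond region.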
3298 
3299 // Create a rank-1 const tensor for the zero point of the source tensor.
3300 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3301  Location loc,
3302  Type srcElemType,
3303  int64_t zp) {
3304  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3305  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3306  if (llvm::isa<FloatType>(srcElemType)) {
3307  auto zpAttr = DenseElementsAttr::get(
3308  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3309  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3310  }
3311  if (llvm::isa<IntegerType>(srcElemType)) {
3312  auto zpAttr =
3313  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3314  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3315  }
3316  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3317  return std::nullopt;
3318 }
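
// Hypothetical usage sketch of the helper above: for an i8 element type and
// zp = 0,
//   std::optional<Value> zpVal =
//       createZeroPointTensor(builder, loc, builder.getIntegerType(8), 0);
// materializes a rank-1 tosa.const of type tensor<1xi8>. Float element types
// produce a float splat instead, and unsupported element types return
// std::nullopt.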
3319 
3320 //===----------------------------------------------------------------------===//
3321 // TOSA Shape and Shape Operators Helper functions.
3322 //===----------------------------------------------------------------------===//
3323 
3324 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3325  return mlir::isa<tosa::shapeType>(t);
3326 }
3327 
3328 LogicalResult
3329 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3330  int rank) {
3331  if (rank < 0)
3332  return emitError() << "invalid rank (must be >= 0): " << rank;
3333  return success();
3334 }
3335 
3336 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3337  for (auto v : op->getOperands()) {
3338  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3339  Operation *definingOp = v.getDefiningOp();
3340  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3341  return op->emitOpError("shape operand is not compile time resolvable");
3342  }
3343  }
3344  }
3345  return success();
3346 }
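
// In practice the check above means every !tosa.shape operand must be
// produced by an op that carries the TosaShapeOperator trait (for example a
// tosa.const_shape); a shape value arriving as a block argument has no
// defining op and is reported as not compile time resolvable.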
3347 
3348 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3349  for (auto type : op->getOperandTypes()) {
3350  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3351  return op->emitOpError("must have operands with tosa shape type");
3352  }
3353  }
3354  for (auto type : op->getResultTypes()) {
3355  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3356  return op->emitOpError("must have result with tosa shape type");
3357  }
3358  }
3359  return success();
3360 }
3361 
3362 LogicalResult
3363 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
3364  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
3365  failed(verifyTosaShapeOperator(op)))
3366  return failure();
3367 
3368  // Helper lambda that returns the rank of a tosa shape type.
3369  auto getRank = [](const Type type) {
3370  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
3371  };
3372  auto operandTypes = op->getOperandTypes();
3373  auto resultTypes = op->getResultTypes();
3374 
3375  auto rank = getRank(*op->getOperandTypes().begin());
3376  for (auto type : operandTypes) {
3377  if (getRank(type) != rank) {
3378  return op->emitOpError("operands don't have matching ranks");
3379  }
3380  }
3381  for (auto type : resultTypes) {
3382  if (getRank(type) != rank) {
3383  return op->emitOpError("result shape has different rank than operands");
3384  }
3385  }
3386  return success();
3387 }
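
// Illustration of the rank checks above, with hypothetical types: an op
// carrying this trait whose first operand is !tosa.shape<3> requires every
// other shape operand and every shape result to also have rank 3; mixing in a
// !tosa.shape<2> operand triggers "operands don't have matching ranks".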
3388 
3389 //===----------------------------------------------------------------------===//
3390 // TOSA Shape Operators verify functions.
3391 //===----------------------------------------------------------------------===//
3392 
3393 LogicalResult tosa::ConstShapeOp::verify() {
3394  // Check that the values attribute has rank 1.
3395  auto valuesRank = getValues().getType().getRank();
3396  if (valuesRank != 1)
3397  return emitOpError("expect elements in attribute values with rank 1");
3398  // Check that the number of elements in the values attribute equals the rank of the result shape.
3399  auto count = getValues().getNumElements();
3400  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
3401  if (!(count == rank || (count == 1 && rank == 0))) {
3402  return emitOpError("expect number of elements in attribute values (")
3403  << count << ") to be equal to the rank (" << rank
3404  << ") for the result shape type";
3405  }
3406  return success();
3407 }
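
// Illustration of the element-count rule above, with hypothetical values: a
// tosa.const_shape producing !tosa.shape<4> needs a rank-1 values attribute
// with four elements, e.g. values = dense<[1, 2, 3, 4]> : tensor<4xindex>,
// while a !tosa.shape<0> result accepts a single-element values attribute.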
3408 
3409 //===----------------------------------------------------------------------===//
3410 // TOSA Attribute Definitions.
3411 //===----------------------------------------------------------------------===//
3412 
3413 #define GET_ATTRDEF_CLASSES
3414 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
3415 
3416 //===----------------------------------------------------------------------===//
3417 // TOSA Type Definitions.
3418 //===----------------------------------------------------------------------===//
3419 #define GET_TYPEDEF_CLASSES
3420 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
3421 
3422 //===----------------------------------------------------------------------===//
3423 // TOSA Operator Definitions.
3424 //===----------------------------------------------------------------------===//
3425 
3426 #define GET_OP_CLASSES
3427 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"