1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 #include <numeric>
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // Tosa dialect interface includes.
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47 
48 namespace {
49 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Dialect Function Inliner Interface.
53 //===----------------------------------------------------------------------===//
54 struct TosaInlinerInterface : public DialectInlinerInterface {
55  using DialectInlinerInterface::DialectInlinerInterface;
56 
57  //===--------------------------------------------------------------------===//
58  // Analysis Hooks.
59  //===--------------------------------------------------------------------===//
60 
61  /// All operations can be inlined by default.
62  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63  IRMapping &map) const final {
64  return true;
65  }
66 
67  /// All regions with If and While parent operators can be inlined.
68  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69  IRMapping &map) const final {
70  return (isa<tosa::IfOp>(dest->getParentOp()) ||
71  isa<tosa::WhileOp>(dest->getParentOp()));
72  }
73 };
74 
75 /// This class implements the bytecode interface for the Tosa dialect.
76 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77  TosaDialectBytecodeInterface(Dialect *dialect)
78  : BytecodeDialectInterface(dialect) {}
79 
80  //===--------------------------------------------------------------------===//
81  // Attributes
82 
83  Attribute readAttribute(DialectBytecodeReader &reader) const override {
84  return ::readAttribute(getContext(), reader);
85  }
86 
87  LogicalResult writeAttribute(Attribute attr,
88  DialectBytecodeWriter &writer) const override {
89  return ::writeAttribute(attr, writer);
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Types
94 
95  Type readType(DialectBytecodeReader &reader) const override {
96  return ::readType(getContext(), reader);
97  }
98 
99  LogicalResult writeType(Type type,
100  DialectBytecodeWriter &writer) const override {
101  return ::writeType(type, writer);
102  }
103 
104  void writeVersion(DialectBytecodeWriter &writer) const final {
105  // TODO: Populate.
106  }
107 
108  std::unique_ptr<DialectVersion>
109  readVersion(DialectBytecodeReader &reader) const final {
110  // TODO: Populate
111  reader.emitError("Dialect does not support versioning");
112  return nullptr;
113  }
114 
115  LogicalResult upgradeFromVersion(Operation *topLevelOp,
116  const DialectVersion &version) const final {
117  return success();
118  }
119 };
120 
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // TOSA control flow support.
125 //===----------------------------------------------------------------------===//
126 
127 /// Returns the while loop body.
128 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129  return {&getBodyGraph()};
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA variable operator support.
134 //===----------------------------------------------------------------------===//
135 
136 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137  return to_vector(llvm::map_range(shape, [](int64_t dim) {
138  return dim == -1 ? ShapedType::kDynamic : dim;
139  }));
140 }
141 
142 // Returns the ranked tensor type described by a tosa.variable op.
143 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144  Type elementType = variableOp.getType();
145  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147  return RankedTensorType::get(shape, elementType);
148 }
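// For illustration (values hypothetical): a declaration written as
//   tosa.variable @state = dense<0.0> : tensor<2x4xf32>
// stores var_shape = [2, 4] and type = f32, so getVariableType returns
// tensor<2x4xf32>; a -1 entry in var_shape maps to a dynamic dimension.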
149 
150 //===----------------------------------------------------------------------===//
151 // Tosa dialect initialization.
152 //===----------------------------------------------------------------------===//
153 
154 void TosaDialect::initialize() {
155  addTypes<
156 #define GET_TYPEDEF_LIST
157 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158  >();
159  addOperations<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162  >();
163  addAttributes<
164 #define GET_ATTRDEF_LIST
165 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166  >();
167  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168  declarePromisedInterfaces<
169  shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175  GreaterEqualOp, MatMulOp>();
176 }
177 
178 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179  Type type, Location loc) {
180  // Tosa dialect constants only support ElementsAttr unlike standard dialect
181  // constant which supports all attributes.
182  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183  return tosa::ConstShapeOp::create(builder, loc, type,
184  llvm::cast<DenseIntElementsAttr>(value));
185  }
186  if (llvm::isa<ElementsAttr>(value))
187  return tosa::ConstOp::create(builder, loc, type,
188  llvm::cast<ElementsAttr>(value));
189  return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Parsers and printers
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199  DenseElementsAttr &varShapeAttr,
200  TypeAttr &typeAttr) {
201  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202  if (!shapedType.hasRank())
203  return parser.emitError(parser.getCurrentLocation())
204  << "expected ranked type";
205 
206  auto elementType = shapedType.getElementType();
207  typeAttr = TypeAttr::get(elementType);
208  ArrayRef<int64_t> shape = shapedType.getShape();
209  Builder builder(parser.getContext());
210  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211  return success();
212  }
213  return parser.emitError(parser.getCurrentLocation())
214  << "expected shaped type";
215 }
216 
217 } // namespace
218 
219 // parses the optional initial value or type for a tosa variable
220 // with initial value:
221 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222 //
223 // without initial value:
224 // tosa.variable @name : tensor<1x8xf32>
225 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227  Attribute &initialValueAttr) {
228  if (succeeded(parser.parseOptionalEqual())) {
229  if (failed(parser.parseAttribute(initialValueAttr))) {
230  return parser.emitError(parser.getCurrentLocation())
231  << "expected attribute";
232  }
233  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235  typeAttr);
236  }
237  return parser.emitError(parser.getCurrentLocation())
238  << "expected Typed attr";
239  }
240 
241  initialValueAttr = nullptr;
242  Type parsedType;
243  if (failed(parser.parseColonType(parsedType))) {
244  return parser.emitError(parser.getCurrentLocation())
245  << "expected type after colon";
246  }
247  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248 }
249 
250 void mlir::tosa::printVariableOpTypeOrInitialValue(
251  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252  TypeAttr typeAttr, Attribute initialValueAttr) {
253  bool needsSpace = false;
254  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255  auto shape =
256  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257  Type elementType = typeAttr.getValue();
258  RankedTensorType tensorType =
259  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260  auto tensorTypeAttr = TypeAttr::get(tensorType);
261  p << ": ";
262  p.printAttribute(tensorTypeAttr);
263  needsSpace = true; // subsequent attr value needs a space separator
264  }
265  if (initialValueAttr) {
266  if (needsSpace)
267  p << ' ';
268  p << "= ";
269  p.printAttribute(initialValueAttr);
270  }
271 }
272 
273 namespace {
274 
275 // parse attributes with special handling for tosa enum attributes
276 template <typename EnumType>
277 ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278  NamedAttrList &outAttrs) {
279  llvm::StringRef name;
280  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281  return failure();
282 
283  // special handling: rounding_mode accepts a *bare* RoundingMode enum
284  // keyword.
285  llvm::StringRef kw;
286  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287  if (name == "rounding_mode" &&
288  succeeded(parser.parseOptionalKeyword(&kw))) {
289  auto sym = symbolizeRoundingMode(kw);
290  if (!sym)
291  return parser.emitError(parser.getCurrentLocation())
292  << "invalid rounding_mode value: " << kw;
293  auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294  outAttrs.push_back(NamedAttribute(name, attr));
295  return success();
296  }
297  }
298  // special handling: mode accepts a *bare* ResizeMode enum keyword.
299  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300  if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301  auto sym = symbolizeResizeMode(kw);
302  if (!sym)
303  return parser.emitError(parser.getCurrentLocation())
304  << "invalid resize mode value: " << kw;
305  auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306  outAttrs.push_back(NamedAttribute(name, attr));
307  return success();
308  }
309  }
310  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311  // keyword.
312  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313  if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314  auto sym = symbolizeNanPropagationMode(kw);
315  if (!sym)
316  return parser.emitError(parser.getCurrentLocation())
317  << "invalid nan_mode value: " << kw;
318  auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319  outAttrs.push_back(NamedAttribute(name, attr));
320  return success();
321  }
322  }
323 
324  // Default path: parse any normal attribute literal, including fully qualified
325  // enum keyword
326  Attribute attr;
327  return parser.parseAttribute(attr, name, outAttrs);
328 }
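// Sketch of the accepted input (attribute values illustrative): both of the
// following spellings of rounding_mode parse to the same RoundingModeAttr,
//   rounding_mode = SINGLE_ROUND
//   rounding_mode = #tosa<rounding_mode SINGLE_ROUND>
// where the second form is assumed to be the fully qualified attribute syntax;
// the bare `mode` and `nan_mode` keywords are handled the same way.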
329 
330 template <typename EnumType>
331 ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
332  // parse operands
333  SmallVector<OpAsmParser::UnresolvedOperand> operands;
334  if (parser.parseCommaSeparatedList(
335  [&]() { return parser.parseOperand(operands.emplace_back()); }))
336  return failure();
337 
338  // Parse { attr-dict } with special handling for enum bare token
339  NamedAttrList attrs;
340  if (succeeded(parser.parseOptionalLBrace()) &&
341  failed(parser.parseOptionalRBrace())) {
342  do {
343  if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
344  return failure();
345  } while (succeeded(parser.parseOptionalComma()));
346  if (parser.parseRBrace())
347  return failure();
348  }
349 
350  FunctionType fnTy;
351  if (parser.parseColonType(fnTy))
352  return failure();
353 
354  // Resolve operands and types
355  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
356  parser.getCurrentLocation(),
357  result.operands)))
358  return failure();
359 
360  result.addTypes(fnTy.getResult(0));
361  result.addAttributes(attrs);
362 
363  return success();
364 }
365 
366 void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
367  parser << namedAttr.getName().strref() << " = ";
368  auto attr = namedAttr.getValue();
369  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
370  parser << roundingModeAttr.getValue();
371  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
372  parser << resizeModeAttr.getValue();
373  } else if (auto nanPropagationModeAttr =
374  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
375  parser << nanPropagationModeAttr.getValue();
376  } else {
377  parser.printAttribute(attr);
378  }
379 }
380 
381 // print with special handling for default valued NanPropagationMode attribute
382 void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
383  parser << " ";
384  parser.printOperands(op->getOperands());
385 
386  NamedAttrList toPrint(op->getAttrs());
387  // remove default NanPropagate attribute
388  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
389  for (auto attr : op->getAttrs()) {
390  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
391  if (nanAttr.getValue() == kDefaultNanValue) {
392  // elide from toPrint
393  toPrint.erase(attr.getName());
394  break;
395  }
396  }
397  }
398 
399  if (!toPrint.empty()) {
400  parser << " {";
401  llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
402  printNamedAttr(parser, namedAttr);
403  });
404  parser << "}";
405  }
406 
407  parser << " : ";
408  parser.printFunctionalType(op);
409 }
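// Net effect (illustrative): a nan_mode of PROPAGATE, being the default, is
// elided from the printed attribute dictionary, while nan_mode = IGNORE is
// printed explicitly alongside the remaining attributes.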
410 
411 // print with special handling for enums: RoundingMode, ResizeMode
412 void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
413  parser << " ";
414  parser.printOperands(op->getOperands());
415 
416  if (!op->getAttrs().empty()) {
417  parser << " {";
418  llvm::interleaveComma(op->getAttrs(), parser,
419  [&](const NamedAttribute namedAttr) {
420  printNamedAttr(parser, namedAttr);
421  });
422  parser << "}";
423  }
424 
425  parser << " : ";
426  parser.printFunctionalType(op);
427 }
428 
429 } // namespace
430 
431 ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
432  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
433 }
434 
435 void RescaleOp::print(OpAsmPrinter &parser) {
436  printWithEnumHandling(parser, *this);
437 }
438 
439 ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
440  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
441 }
442 
443 void ApplyScaleOp::print(OpAsmPrinter &parser) {
444  printWithEnumHandling(parser, *this);
445 }
446 
447 ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
448  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
449 }
450 
451 void ResizeOp::print(OpAsmPrinter &parser) {
452  printWithEnumHandling(parser, *this);
453 }
454 
455 ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
456  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
457 }
458 
459 void ArgMaxOp::print(OpAsmPrinter &parser) {
460  printWithNanPropagationHandling(parser, *this);
461 }
462 
463 ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
464  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
465 }
466 
467 void MaxPool2dOp::print(OpAsmPrinter &parser) {
468  printWithNanPropagationHandling(parser, *this);
469 }
470 
471 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
472  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
473 }
474 
475 void ClampOp::print(OpAsmPrinter &parser) {
476  printWithNanPropagationHandling(parser, *this);
477 }
478 
479 ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
480  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
481 }
482 
483 void MaximumOp::print(OpAsmPrinter &parser) {
484  printWithNanPropagationHandling(parser, *this);
485 }
486 
487 ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
488  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
489 }
490 
491 void MinimumOp::print(OpAsmPrinter &parser) {
492  printWithNanPropagationHandling(parser, *this);
493 }
494 
495 ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
496  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
497 }
498 
499 void ReduceMaxOp::print(OpAsmPrinter &parser) {
500  printWithNanPropagationHandling(parser, *this);
501 }
502 
503 ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
504  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
505 }
506 
507 void ReduceMinOp::print(OpAsmPrinter &parser) {
508  printWithNanPropagationHandling(parser, *this);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Tosa utilities.
513 //===----------------------------------------------------------------------===//
514 
515 static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
516  if (lhs % rhs != 0)
517  return std::nullopt;
518  return lhs / rhs;
519 }
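// For example, idivCheck(16, 4) yields 4, while idivCheck(16, 3) yields
// std::nullopt because 16 is not evenly divisible by 3.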
520 
521 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
522  auto srcType = getElementTypeOrSelf(type);
523  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
524  srcType = quantType.getStorageType();
525  return srcType;
526 }
527 
528 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
529  return getStorageElementTypeOrSelf(value.getType());
530 }
531 
532 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
533  Value valZp, StringRef name) {
534  Type eType = getStorageElementTypeOrSelf(val.getType());
535  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
536 
537  bool bothInts =
538  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
539  bool sameBitWidth =
540  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
541 
542  if (!bothInts || !sameBitWidth) {
543  return op->emitOpError()
544  << "expected " << name << " and " << name
545  << "_zp to both be integer of the same bitwidth, but got " << eType
546  << " vs. " << eZpType;
547  }
548  return success();
549 }
550 
551 // Create a pad-const constant tensor holding `val`, with the required data type.
552 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, const Location &loc,
553  Value src, int32_t val) {
554  const auto srcType = getElementTypeOrSelf(src);
555  const auto srcElemType = getStorageElementTypeOrSelf(src);
556  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
557  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
558  const auto padConstAttr{
559  llvm::isa<FloatType>(srcElemType)
560  ? DenseElementsAttr::get(padConstEType,
561  builder.getFloatAttr(srcElemType, val))
562  : DenseElementsAttr::get(padConstEType,
563  builder.getIntegerAttr(srcElemType, val))};
564  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
565 }
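// For example (types illustrative): an i8 source with val = 0 produces a
// tosa.const of type tensor<1xi8> holding dense<0>, while an f32 source
// produces a tensor<1xf32> holding dense<0.0>.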
566 
567 //===----------------------------------------------------------------------===//
568 // TOSA Operator Verifiers.
569 //===----------------------------------------------------------------------===//
570 
571 template <typename T>
572 static LogicalResult verifyConvOp(T op) {
573  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
574  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
575 
576  auto inputEType = inputType.getElementType();
577  auto weightEType = weightType.getElementType();
578  auto biasEType =
579  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
580  auto resultEType =
581  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
582  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
583  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
584 
585  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
586  inputEType = quantType.getStorageType();
587 
588  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
589  weightEType = quantType.getStorageType();
590 
591  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
592  biasEType = quantType.getStorageType();
593 
594  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
595  resultEType = quantType.getStorageType();
596 
597  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
598  // for now, only enforce bias element type == result element type for
599  // float types.
600  op.emitOpError(
601  "expect both bias and result to have same element type, got ")
602  << biasEType << " and " << resultEType;
603  return failure();
604  }
605 
606  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
607  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
608  if (inputEType != weightEType) {
609  op.emitOpError(
610  "expect both input and weight to have same element type, got ")
611  << inputEType << " and " << weightEType;
612  return failure();
613  }
614  }
615 
616  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
617  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
618 
619  // Either both must be float or both non-float.
620  if (inputIsFloat != weightIsFloat) {
621  op.emitOpError(
622  "expect both input and weight to be float or not together, got ")
623  << inputEType << " and " << weightEType;
624  return failure();
625  }
626 
627  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
628  if (inputEType != inputZpEType) {
629  return op.emitOpError("expect both input and its zero point are the same "
630  "element type, got ")
631  << inputEType << " and " << inputZpEType;
632  }
633 
634  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
635  if (weightEType != weightZpEType) {
636  return op.emitOpError("expect both weight and its zero point are the same "
637  "element type, got ")
638  << weightEType << " and " << weightZpEType;
639  }
640 
641  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
642  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
643  return failure();
644 
645  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
646  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
647  return failure();
648 
649  return success();
650 }
651 
652 LogicalResult tosa::ConstOp::verify() {
653 
654  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
655  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
656 
657  if (!attrType || !outputType) {
658  emitOpError("expected tensors for attr/result type");
659  return failure();
660  }
661 
662  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
663  outputType.getElementType())) {
664  if (result.getStorageType() == attrType.getElementType())
665  return success();
666  }
667 
668  if (attrType.getElementType() != outputType.getElementType()) {
669  emitOpError("expected same attr/result element types");
670  return failure();
671  }
672 
673  return success();
674 }
675 
676 template <typename T>
677 static LogicalResult verifyConvOpModes(T op) {
678  auto inputEType =
679  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
680 
681  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
682  inputEType = quantType.getStorageType();
683 
684  auto accType = op.getAccType();
685  if (inputEType.isInteger(8) && !accType.isInteger(32))
686  return op.emitOpError("accumulator type for i8 tensor is not i32");
687 
688  if (inputEType.isInteger(16) && !accType.isInteger(48))
689  return op.emitOpError("accumulator type for i16 tensor is not i48");
690 
691  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
692  return op.emitOpError("accumulator type for f8 tensor is not f16");
693 
694  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
695  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
696 
697  if (inputEType.isBF16() && !accType.isF32())
698  return op.emitOpError("accumulator type for bf16 tensor is not f32");
699 
700  if (inputEType.isF32() && !accType.isF32())
701  return op.emitOpError("accumulator type for f32 tensor is not f32");
702 
703  auto resultEType =
704  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
705 
706  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
707  resultEType = quantType.getStorageType();
708 
709  return success();
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // ERROR_IF functions.
714 // ERROR_IF is a predicate that must set an error if the condition holds.
715 //===----------------------------------------------------------------------===//
716 
717 template <typename T>
718 static LogicalResult verifyConvOpErrorIf(T op) {
719  llvm::ArrayRef<int64_t> padding = op.getPad();
720  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
721  return op.emitOpError("expect all padding values to be >= 0, got ")
722  << padding;
723 
724  llvm::ArrayRef<int64_t> strides = op.getStride();
725  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
726  return op.emitOpError("expect all stride values to be >= 1, got ")
727  << strides;
728 
729  llvm::ArrayRef<int64_t> dilations = op.getDilation();
730  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
731  return op.emitOpError("expect all dilation values to be >= 1, got ")
732  << dilations;
733 
734  const RankedTensorType outputType =
735  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
736  if (!outputType)
737  // Skip following checks if output is not ranked
738  return success();
739 
740  const RankedTensorType inputType =
741  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
742  const RankedTensorType weightType =
743  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
744 
745  if (inputType && weightType) {
746  const auto verifyOutputSize =
747  [&op](const int64_t inputSize, const int64_t kernelSize,
748  const int64_t outputSize, const int64_t padBefore,
749  const int64_t padAfter, const int64_t stride,
750  const int64_t dilation, const llvm::StringRef dimName,
751  const llvm::StringRef dimAxis,
752  const llvm::StringRef padBeforeName,
753  const llvm::StringRef padAfterName) -> LogicalResult {
754  if (inputSize == ShapedType::kDynamic ||
755  kernelSize == ShapedType::kDynamic)
756  return success();
757 
758  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
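  // Worked example with illustrative sizes: I = 15, K = 3, pa = pb = 1, d = 1,
  // s = 2 gives idiv_check(15 - 1 + 1 + 1 - 2, 2) + 1 = 7 + 1 = 8, so any
  // static output extent other than 8 is rejected, as is a stride that does
  // not evenly divide the numerator.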
759 
760  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
761  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
762  stride);
763  if (!calculatedOutSizeMinusOne.has_value())
764  return op.emitOpError("expected input_")
765  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
766  << padAfterName << " - (kernel_" << dimName
767  << " - 1) * dilation_" << dimAxis
768  << " to be wholly divisible by stride_" << dimAxis << ", got ("
769  << inputSize << " - 1 + " << padBefore << " + " << padAfter
770  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
771  << stride;
772 
773  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
774  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
775  return op.emitOpError("calculated output ")
776  << dimName << " did not match expected: "
777  << "calculated=" << calculatedOutSize
778  << ", expected=" << outputSize;
779 
780  return success();
781  };
782 
783  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
784  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
785  if (failed(verifyOutputSize(
786  inputType.getDimSize(1), weightType.getDimSize(1),
787  outputType.getDimSize(1), padding[0], padding[1], strides[0],
788  dilations[0], "height", "y", "top", "bottom")))
789  return failure();
790 
791  if (failed(verifyOutputSize(
792  inputType.getDimSize(2), weightType.getDimSize(2),
793  outputType.getDimSize(2), padding[2], padding[3], strides[1],
794  dilations[1], "width", "x", "left", "right")))
795  return failure();
796  }
797 
798  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
799  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
800  if (failed(verifyOutputSize(
801  inputType.getDimSize(1), weightType.getDimSize(0),
802  outputType.getDimSize(1), padding[0], padding[1], strides[0],
803  dilations[0], "height", "y", "top", "bottom")))
804  return failure();
805 
806  if (failed(verifyOutputSize(
807  inputType.getDimSize(2), weightType.getDimSize(1),
808  outputType.getDimSize(2), padding[2], padding[3], strides[1],
809  dilations[1], "width", "x", "left", "right")))
810  return failure();
811  }
812 
813  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
814  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
815  if (failed(verifyOutputSize(
816  inputType.getDimSize(1), weightType.getDimSize(1),
817  outputType.getDimSize(1), padding[0], padding[1], strides[0],
818  dilations[0], "depth", "d", "front", "back")))
819  return failure();
820 
821  if (failed(verifyOutputSize(
822  inputType.getDimSize(2), weightType.getDimSize(2),
823  outputType.getDimSize(2), padding[2], padding[3], strides[1],
824  dilations[1], "height", "y", "top", "bottom")))
825  return failure();
826 
827  if (failed(verifyOutputSize(
828  inputType.getDimSize(3), weightType.getDimSize(3),
829  outputType.getDimSize(3), padding[4], padding[5], strides[2],
830  dilations[2], "width", "x", "left", "right")))
831  return failure();
832  }
833  }
834 
835  const RankedTensorType biasType =
836  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
837  if (!biasType)
838  // Skip following checks if bias is not ranked
839  return success();
840 
841  const int64_t biasChannels = biasType.getDimSize(0);
842  const int64_t outputChannels =
843  outputType.getDimSize(outputType.getRank() - 1);
844  if (biasChannels == ShapedType::kDynamic ||
845  outputChannels == ShapedType::kDynamic)
846  // Skip following checks if biasChannels or outputChannels is dynamic dim
847  return success();
848 
849  if (biasChannels != outputChannels && biasChannels != 1)
850  return op.emitOpError(
851  "bias channels expected to be equal to output channels (")
852  << outputChannels << ") or 1, got " << biasChannels;
853 
854  return success();
855 }
856 
857 // Verify that the two given types have the same element type and compatible shapes.
858 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
859  StringRef name1, Type type2,
860  StringRef name2) {
861  auto shapeType1 = dyn_cast<ShapedType>(type1);
862  auto shapeType2 = dyn_cast<ShapedType>(type2);
863  if (!shapeType1 || !shapeType2)
864  return failure();
865 
866  auto elemType1 = shapeType1.getElementType();
867  auto elemType2 = shapeType2.getElementType();
868  if (elemType1 != elemType2)
869  return op->emitOpError()
870  << "require same element type for " << name1 << " (" << elemType1
871  << ") and " << name2 << " (" << elemType2 << ")";
872 
873  if (failed(verifyCompatibleShape(type1, type2)))
874  return op->emitOpError()
875  << "require same shapes for " << name1 << " (" << type1 << ") and "
876  << name2 << " (" << type2 << ")";
877 
878  return success();
879 }
880 
881 // Verify that the two given tensor lists have the same length, element types, and shapes.
882 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
883  StringRef name1,
884  ValueRange list2,
885  StringRef name2) {
886  if (list1.size() != list2.size())
887  return op->emitOpError()
888  << "require same number of values in " << name1 << " ("
889  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
890 
891  for (auto [type1, type2] :
892  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
893  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
894  return failure();
895  }
896 
897  return success();
898 }
899 
900 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
901  ShapeAdaptor shapeAdaptor(type);
902  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
903  return success();
904 
905  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
906 }
907 
908 template <typename T>
909 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
910  Operation *symTableOp =
911  op->template getParentWithTrait<OpTrait::SymbolTable>();
912  if (!symTableOp)
913  // If the operation is not enclosed in a symbol table, we cannot
914  // verify it against its declaration.
915  return success();
916 
917  SymbolTable symTable(symTableOp);
918  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
919 
920  // Verify prior declaration
921  if (!varOp)
922  return op->emitOpError("'")
923  << op.getName() << "' has not been declared by 'tosa.variable'";
924 
925  // Verify type and shape
926  auto variableType = getVariableType(varOp);
927  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
928  "the input tensor")
929  .failed())
930  return failure();
931  return success();
932 }
933 
934 // Verify that inType and outType have the same element type.
935 template <typename T>
936 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
937  auto inputType = llvm::dyn_cast<TensorType>(inType);
938  auto outputType = llvm::dyn_cast<TensorType>(outType);
939  if (!inputType) {
940  op.emitOpError("expect shaped tensor for input, got ") << inType;
941  return failure();
942  }
943  if (!outputType) {
944  op.emitOpError("expect shaped tensor for output, got ") << outType;
945  return failure();
946  }
947  auto inputElementType = inputType.getElementType();
948  auto outputElementType = outputType.getElementType();
949  auto inputQuantType =
950  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
951  auto outputQuantType =
952  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
953  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
954  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
955  inputElementType != outputElementType) {
956  // Only check when both element types are int/index/float/UniformQuantized;
957  // it is unclear how to compare generic quant::QuantizedType values.
958  // This case occurs in test_conv2d_q_grouped_convolution in
959  // tfl-to-tosa-pipeline.mlir.
960  op.emitOpError("expect input and output to have same element type, got ")
961  << inputElementType << " and " << outputElementType;
962  return failure();
963  }
964  return success();
965 }
966 
967 LogicalResult tosa::ArgMaxOp::verify() {
968  const ShapedType resultType = llvm::cast<ShapedType>(getType());
969 
970  // Ensure the output element type is an integer type
971  if (const auto resultETy = resultType.getElementType();
972  !resultETy.isIntOrIndex())
973  return emitOpError("result tensor is not of integer type");
974 
975  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
976  if (!inputType.hasRank())
977  return success();
978 
979  // Ensure axis is within the tensor rank
980  const int64_t axis = getAxisAttr().getInt();
981  if (((axis < 0) || axis >= inputType.getRank()))
982  return emitOpError("specified axis is outside the rank of the tensor");
983 
984  if (!resultType.hasRank())
985  return success();
986 
987  const ArrayRef<int64_t> inputShape = inputType.getShape();
988  const ArrayRef<int64_t> outputShape = resultType.getShape();
989  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
990  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
991  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
992  return emitOpError("expected output shape '")
993  << expectedOutputShape << "', got '" << outputShape << "'";
994 
995  return success();
996 }
997 
998 template <typename T>
999 static LogicalResult verifyPoolingOp(T op) {
1000  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1001  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1002  return op.emitOpError("expect all kernel values to be >= 1, got ")
1003  << kernel;
1004 
1005  const llvm::ArrayRef<int64_t> strides = op.getStride();
1006  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1007  return op.emitOpError("expect all stride values to be >= 1, got ")
1008  << strides;
1009 
1010  const llvm::ArrayRef<int64_t> padding = op.getPad();
1011  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1012  return op.emitOpError("expect all padding values to be >= 0, got ")
1013  << padding;
1014 
1015  // Padding must be less than kernel size to avoid a divide-by-zero
1016  const int64_t kernelX = kernel[1];
1017  const int64_t padLeft = padding[2];
1018  const int64_t padRight = padding[3];
1019  if (padRight >= kernelX || padLeft >= kernelX)
1020  return op.emitOpError("expected left/right padding to be less than the "
1021  "width of the kernel, got pad_left=")
1022  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1023 
1024  const int64_t kernelY = kernel[0];
1025  const int64_t padTop = padding[0];
1026  const int64_t padBottom = padding[1];
1027  if (padTop >= kernelY || padBottom >= kernelY)
1028  return op.emitOpError("expected top/bottom padding to be less than the "
1029  "height of the kernel, got pad_top=")
1030  << padTop << ", pad_bottom=" << padBottom
1031  << ", kernel_y=" << kernelY;
1032 
1033  const auto inputType =
1034  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1035  const auto outputType =
1036  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1037  if (!inputType || !outputType)
1038  return success();
1039 
1040  const auto verifyOutputSize =
1041  [&op](const int64_t inputSize, const int64_t outputSize,
1042  const int64_t kernelSize, const int64_t strideSize,
1043  const int64_t padBefore, const int64_t padAfter,
1044  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1045  const llvm::StringRef padBeforeName,
1046  const llvm::StringRef padAfterName) -> LogicalResult {
1047  if (ShapedType::isDynamic(inputSize))
1048  return success();
1049 
1050  const std::optional<int64_t> calculatedOutSizeMinusOne =
1051  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1052  if (!calculatedOutSizeMinusOne.has_value())
1053  return op.emitOpError("expected input_")
1054  << dimName << " + pad_" << padBeforeName << " + pad_"
1055  << padAfterName << " - kernel_" << dimAxis
1056  << " to be wholly divisible by stride_" << dimAxis << ", got ("
1057  << inputSize << " + " << padBefore << " + " << padAfter << " - "
1058  << kernelSize << ") / " << strideSize;
1059 
1060  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1061  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1062  return op.emitOpError("calculated output ")
1063  << dimName << " did not match expected: "
1064  << "calculated=" << calculatedOutSize
1065  << ", expected=" << outputSize;
1066 
1067  return success();
1068  };
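  // Worked example with illustrative sizes: inputSize = 13, kernel = 3,
  // padBefore = padAfter = 1, stride = 2 gives idivCheck(13 + 1 + 1 - 3, 2) + 1
  // = 6 + 1 = 7 as the only accepted static output extent.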
1069 
1070  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1071  kernel[0], strides[0], padding[0], padding[1],
1072  "height", "y", "top", "bottom")))
1073  return failure();
1074 
1075  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1076  kernel[1], strides[1], padding[2], padding[3],
1077  "width", "x", "left", "right")))
1078  return failure();
1079 
1080  return success();
1081 }
1082 
1083 LogicalResult tosa::AvgPool2dOp::verify() {
1084  if (failed(verifyPoolingOp(*this)))
1085  return failure();
1086 
1087  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1088  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1089  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1090  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1091 
1092  auto accType = getAccType();
1093  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1094  return emitOpError("accumulator type for integer tensor is not i32");
1095 
1096  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1097  return emitOpError("accumulator type for f16 tensor is not f16/f32");
1098 
1099  if (inputETy.isBF16() && !accType.isF32())
1100  return emitOpError("accumulator type for bf16 tensor is not f32");
1101 
1102  if (inputETy.isF32() && !accType.isF32())
1103  return emitOpError("accumulator type for f32 tensor is not f32");
1104 
1105  if (inputETy != inputZpETy)
1106  return emitOpError("expect both input and its zero point are the same "
1107  "element type, got ")
1108  << inputETy << " and " << inputZpETy;
1109 
1110  if (resultETy != outputZpETy)
1111  return emitOpError("expect both output and its zero point are the same "
1112  "element type, got ")
1113  << resultETy << " and " << outputZpETy;
1114 
1115  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1116  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1117  return failure();
1118 
1119  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1120  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1121  return failure();
1122 
1123  return success();
1124 }
1125 
1126 LogicalResult tosa::ClampOp::verify() {
1127  mlir::Type inputETy =
1128  llvm::cast<ShapedType>(getInput().getType()).getElementType();
1129  if (auto quantType =
1130  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1131  inputETy = quantType.getStorageType();
1132  }
1133  mlir::Type outputETy =
1134  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1135  if (auto quantType =
1136  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1137  outputETy = quantType.getStorageType();
1138  }
1139  if (inputETy != outputETy)
1140  return emitOpError("input/output element types are incompatible.");
1141 
1142  auto maxValAttr = getMaxValAttr();
1143  auto minValAttr = getMinValAttr();
1144 
1145  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1146 
1147  if (inputETy.isInteger(dataTypeBitWidth)) {
1148  // if input datatype is integer, check that the min_val/max_val attributes
1149  // are integer attributes, and that their type is the same as the input's
1150  // datatype
1151  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1152  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1153  if (!intMaxValAttr || !intMinValAttr ||
1154  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1155  (intMaxValAttr.getType() != inputETy))
1156  return emitOpError("min/max attributes types are incompatible with "
1157  "input/output element types.");
1158 
1159  const bool isUnsigned = inputETy.isUnsignedInteger();
1160  const bool isBoolean = inputETy.isInteger(1);
1161  const APInt minVal = intMinValAttr.getValue();
1162  const APInt maxVal = intMaxValAttr.getValue();
1163  if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1164  return emitOpError("expected min_val <= max_val, got min_val=")
1165  << minValAttr << ", max_val=" << maxValAttr;
1166  } else {
1167  // otherwise, input datatype is float, check that the min_val/max_val
1168  // attributes share the same type and that their type is the same as the
1169  // input's datatype
1170  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1171  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1172  if (!floatMaxValAttr || !floatMinValAttr ||
1173  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1174  (floatMaxValAttr.getType() != inputETy))
1175  return emitOpError("min/max attributes types are incompatible with "
1176  "input/output element types.");
1177 
1178  const APFloat minVal = floatMinValAttr.getValue();
1179  const APFloat maxVal = floatMaxValAttr.getValue();
1180  if (minVal.isNaN() || maxVal.isNaN())
1181  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1182  << minValAttr << ", max_val=" << maxValAttr;
1183 
1184  if (maxVal < minVal)
1185  return emitOpError("expected min_val <= max_val, got min_val=")
1186  << minValAttr << ", max_val=" << maxValAttr;
1187  }
1188 
1189  return success();
1190 }
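// Illustrative examples (values hypothetical): an i8 clamp with
// min_val = -128 : i8 and max_val = 127 : i8 verifies; an f32 clamp requires
// f32 min_val/max_val attributes that are not NaN and satisfy min_val <= max_val.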
1191 
1192 //===----------------------------------------------------------------------===//
1193 // TOSA Operator Quantization Builders.
1194 //===----------------------------------------------------------------------===//
1195 
1196 /// This builder is called on all convolution operators except TransposeConv,
1197 /// which has specialized output shape semantics. The builder also defines the
1198 /// bitwidth of the output given the bit width of the input & weight content.
1199 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1200  Type outputType, Value input, Value weight,
1201  Value bias, DenseI64ArrayAttr pad,
1202  DenseI64ArrayAttr stride,
1203  DenseI64ArrayAttr dilation,
1204  TypeAttr accType) {
1205  auto zps = createZPsAsConst(builder, input, weight);
1206  result.addOperands({input, weight, bias, zps.first, zps.second});
1207  result.addAttribute("pad", pad);
1208  result.addAttribute("stride", stride);
1209  result.addAttribute("dilation", dilation);
1210  result.addAttribute("acc_type", accType);
1211  Type finalOutputType = outputType;
1212  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1213  if (quantAttr) {
1214  finalOutputType =
1215  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1216  }
1217  result.addTypes(finalOutputType);
1218 }
1219 
1220 /// Handles tosa.transpose_conv2d which has outpad and output shape
1221 /// attributes.
1222 static void
1223 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1224  Type outputType, Value input, Value weight,
1225  Value bias, DenseI64ArrayAttr outpad,
1226  DenseI64ArrayAttr stride, TypeAttr accType) {
1227  auto zps = createZPsAsConst(builder, input, weight);
1228  result.addOperands({input, weight, bias, zps.first, zps.second});
1229  result.addAttribute("out_pad", outpad);
1230  result.addAttribute("stride", stride);
1231  result.addAttribute("acc_type", accType);
1232  Type finalOutputType = outputType;
1233  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1234  if (quantAttr) {
1235  finalOutputType =
1236  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1237  }
1238  result.addTypes(finalOutputType);
1239 }
1240 
1241 /// The tosa.matmul op is also intended to be generated where a fully_connected
1242 /// op must be constructed but the weight is not a constant. In this case,
1243 /// the fully_connected op must be expressed using matmul.
1244 /// TODO: Add link to the legalization document explaining this.
1245 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1246  OperationState &result, Type outputType,
1247  Value a, Value b) {
1248  auto zps = createZPsAsConst(builder, a, b);
1249  result.addOperands({a, b, zps.first, zps.second});
1250 
1251  Type finalOutputType{outputType};
1252  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1253  auto eType = getStorageElementTypeOrSelf(a.getType());
1254  auto inputBits = eType.getIntOrFloatBitWidth();
1255 
1256  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1257  assert(outputShapedType && "Output must be a shaped type");
1258 
1259  IntegerType accElementType;
1260  if (inputBits == 16)
1261  accElementType = builder.getIntegerType(48);
1262  else
1263  accElementType = builder.getI32Type();
1264 
1265  finalOutputType = outputShapedType.clone(accElementType);
1266  }
1267  result.addTypes(finalOutputType);
1268 }
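// For example (operand types illustrative): quantized i8 x i8 operands produce
// an i32 accumulator/result element type, while i16 x i16 operands produce i48.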
1269 
1270 /// Both the tosa.avg_pool2d and unary ops use the same
1271 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
1272 /// it has additional parameters that are not part of the unary ops.
1273 static void
1274 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1275  Type outputType, Value input,
1276  DenseArrayAttr kernel, DenseArrayAttr stride,
1277  DenseArrayAttr pad, TypeAttr accType) {
1278  const Location loc{result.location};
1279  int64_t inputZp{0};
1280  int64_t outputZp{0};
1281 
1282  if (auto quantAttr =
1283  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1284  inputZp = quantAttr.getInputZp();
1285  outputZp = quantAttr.getOutputZp();
1286  }
1287  const std::optional<Value> inputZpOp =
1288  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1289  if (!inputZpOp) {
1290  (void)emitError(
1291  loc,
1292  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1293  }
1294  const std::optional<Value> outputZpOp =
1295  createZeroPointTensor(builder, loc, outputType, outputZp);
1296  if (!outputZpOp) {
1297  (void)emitError(loc, "Failed to create output zero point tensor for "
1298  "quantized AVG_POOL2D op");
1299  }
1300 
1301  if (inputZpOp && outputZpOp) {
1302  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1303  } else {
1304  // Failed to create one or more zero points above: just add the input as
1305  // operand. This will trigger an error when building the op because of the
1306  // missing zero points.
1307  result.addOperands({input});
1308  }
1309  result.addAttribute("kernel", kernel);
1310  result.addAttribute("stride", stride);
1311  result.addAttribute("pad", pad);
1312  result.addAttribute("acc_type", accType);
1313  result.types.push_back(outputType);
1314 }
1315 
1316 /// This builder is called on the single-parameter negate operator
1317 /// to construct input and output zero points based on their
1318 /// types.
1319 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1320  OperationState &result, Type outputType,
1321  Value input) {
1322  const Location loc{result.location};
1323  int64_t input1Zp{0};
1324  int64_t outputZp{0};
1325  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1326  if (quantAttr) {
1327  input1Zp = quantAttr.getInputZp();
1328  outputZp = quantAttr.getOutputZp();
1329  }
1330  const std::optional<Value> input1ZpOp =
1331  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1332  if (!input1ZpOp) {
1333  (void)emitError(
1334  loc, "Failed to create input1 zero point for quantized NEGATE op");
1335  }
1336 
1337  const std::optional<Value> outputZpOp =
1338  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1339  if (!outputZpOp) {
1340  (void)emitError(
1341  loc, "Failed to create output zero point for quantized NEGATE op");
1342  }
1343 
1344  if (input1ZpOp && outputZpOp) {
1345  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1346  } else {
1347  // Failed to create one or more zero points above: just add the input as
1348  // operand. This will trigger an error when building the op because of the
1349  // missing zero points.
1350  result.addOperands({input});
1351  }
1352 
1353  result.types.push_back(outputType);
1354 }
1355 
1356 /// This builder is called on the TOSA pad operator, which needs to create its
1357 /// own OptionalAttr quantization_attr parameter to scale the padding values
1358 /// correctly. A missing pad_const is interpreted as zero padding.
1359 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1360  Type outputType, Value input,
1361  Value paddings) {
1362  const Location loc{result.location};
1363  int32_t zp{0};
1364  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1365  if (quantAttr) {
1366  zp = static_cast<int32_t>(quantAttr.getInputZp());
1367  }
1368  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1369  result.addOperands({input, paddings, padConstOp});
1370  result.types.push_back(outputType);
1371 }
1372 
1373 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1374  StringRef name, Type variableType,
1375  Attribute initialValue) {
1376  const Location loc{result.location};
1377  auto nameAttr = builder.getStringAttr(name);
1378 
1379  auto shapedType = dyn_cast<ShapedType>(variableType);
1380  if (!shapedType) {
1381  (void)emitError(loc, "variable type must be a shaped type");
1382  return;
1383  }
1384  if (!shapedType.hasRank()) {
1385  (void)emitError(loc, "variable type must be a ranked type");
1386  return;
1387  }
1388 
1389  auto elementType = shapedType.getElementType();
1390  auto elementTypeAttr = TypeAttr::get(elementType);
1391  ArrayRef<int64_t> shape = shapedType.getShape();
1392  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1393 
1394  result.addAttribute("sym_name", nameAttr);
1395  result.addAttribute("var_shape", varShapeAttr);
1396  result.addAttribute("type", elementTypeAttr);
1397  result.addAttribute("initial_value", initialValue);
1398 }
1399 
1400 //===----------------------------------------------------------------------===//
1401 // TOSA Operator Return Type Inference.
1402 //===----------------------------------------------------------------------===//
1403 
1404 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1405  SmallVector<int64_t> &outShape) {
1406  int64_t outRank = 0;
1407  for (int i = 0, e = operands.size(); i != e; ++i) {
1408  auto shape = operands.getShape(i);
1409  if (!shape.hasRank()) {
1410  // TODO(jennik): Update function to have better case handling for
1411  // invalid operands and for ranked tensors.
1412  return failure();
1413  }
1414  outRank = std::max<int64_t>(outRank, shape.getRank());
1415  }
1416 
1417  outShape.resize(outRank, 1);
1418 
1419  for (int i = 0, e = operands.size(); i != e; ++i) {
1420  auto shape = operands.getShape(i);
1421  auto rankDiff = outShape.size() - shape.getRank();
1422 
1423  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1424  auto dim1 = outShape[i + rankDiff];
1425  auto dim2 = shape.getDimSize(i);
1426  auto resolvedDim = dim1;
1427 
1428  if (dim1 == 1) {
1429  resolvedDim = dim2;
1430  } else if (dim2 == 1) {
1431  resolvedDim = dim1;
1432  } else if (dim1 != dim2) {
1433  return failure();
1434  }
1435  outShape[i + rankDiff] = resolvedDim;
1436  }
1437  }
1438 
1439  return success();
1440 }
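// Worked example (shapes illustrative): operand shapes [2, 1, 3] and [4, 3]
// resolve to the broadcast shape [2, 4, 3]; two static, unequal, non-unit
// dimensions such as 2 vs. 3 cause the resolution to fail.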
1441 
1442 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1443  MLIRContext *context, ::std::optional<Location> location,
1444  ArgMaxOp::Adaptor adaptor,
1445  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1446  ShapeAdaptor inputShape(adaptor.getInput().getType());
1447  IntegerAttr axis = adaptor.getProperties().axis;
1448  int32_t axisVal = axis.getValue().getSExtValue();
1449 
1450  if (!inputShape.hasRank()) {
1451  inferredReturnShapes.push_back(ShapedTypeComponents());
1452  return success();
1453  }
1454 
1455  SmallVector<int64_t> outShape;
1456  outShape.reserve(inputShape.getRank() - 1);
1457  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1458  if (i == axisVal)
1459  continue;
1460  outShape.push_back(inputShape.getDimSize(i));
1461  }
1462 
1463  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1464  return success();
1465 }
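// For instance (shapes illustrative), an argmax over axis 1 of a tensor<2x3x4xf32>
// input infers the output shape [2, 4]: the reduced axis is simply dropped.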
1466 
1467 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1468  MLIRContext *context, ::std::optional<Location> location,
1469  RFFT2dOp::Adaptor adaptor,
1470  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1471  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1472 
1473  if (!inputShape.hasRank())
1474  return failure();
1475 
1476  llvm::SmallVector<int64_t> outputShape;
1477  outputShape.resize(3, ShapedType::kDynamic);
1478  outputShape[0] = inputShape.getDimSize(0);
1479  outputShape[1] = inputShape.getDimSize(1);
1480  int64_t inWidth = inputShape.getDimSize(2);
1481 
1482  // Note that we can support this calculation symbolically
1483  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1484  if (inWidth != ShapedType::kDynamic)
1485  outputShape[2] = inWidth / 2 + 1;
1486 
1487  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1488  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1489 
1490  return success();
1491 }
1492 
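// A positive value is a power of two iff it has exactly one set bit;
// x & (x - 1) clears the lowest set bit, so it is zero only in that case
// (e.g. 8 & 7 == 0, but 6 & 5 == 4). The dimSize > 0 guard rejects zero and
// negative sizes.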
1493 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1494  const llvm::StringRef dimName) {
1495  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1496  if (!isPowerOfTwo)
1497  return op->emitOpError("expected ")
1498  << dimName << " to be a power of two, got " << dimSize;
1499 
1500  return success();
1501 }
1502 
1503 LogicalResult tosa::RFFT2dOp::verify() {
1504  const auto outputTypes = getResultTypes();
1505  if (failed(verifyCompatibleShapes(outputTypes)))
1506  return emitOpError("expected output shapes to match, got ") << outputTypes;
1507 
1508  const auto inputType =
1509  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1510  if (!inputType)
1511  return success();
1512 
1513  const int64_t height = inputType.getDimSize(1);
1514  if (ShapedType::isStatic(height) &&
1515  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1516  return failure();
1517 
1518  const int64_t width = inputType.getDimSize(2);
1519  if (ShapedType::isStatic(width) &&
1520  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1521  return failure();
1522 
1523  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1524  if (!outputType)
1525  return success();
1526 
1527  // Batch and height input/output dimensions should match
1528  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1529  outputType.getShape().drop_back())))
1530  return emitOpError("expected batch and height dimensions of input/output "
1531  "to match, got input=")
1532  << inputType << " output=" << outputType;
1533 
1534  // Output width dimension expected to be input_width / 2 + 1
1535  const int64_t outputWidth = outputType.getDimSize(2);
1536  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1537  (outputWidth != (width / 2) + 1))
1538  return emitOpError(
1539  "expected output width to be equal to input_width / 2 + 1, got ")
1540  << outputWidth;
1541 
1542  return success();
1543 }
1544 
1545 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1546  MLIRContext *context, ::std::optional<Location> location,
1547  FFT2dOp::Adaptor adaptor,
1548  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1549  inferredReturnShapes.push_back(
1550  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1551  inferredReturnShapes.push_back(
1552  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1553  return success();
1554 }
1555 
1556 LogicalResult tosa::FFT2dOp::verify() {
1557  const auto inputRealType =
1558  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1559  const auto inputImagType =
1560  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1561  if (!inputRealType || !inputImagType)
1562  return success();
1563 
1564  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1565  return ShapedType::isDynamic(a) ? b : a;
1566  };
1567 
1568  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1569  inputImagType.getDimSize(1));
1570  if (ShapedType::isStatic(height) &&
1571  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1572  return failure();
1573 
1574  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1575  inputImagType.getDimSize(2));
1576  if (ShapedType::isStatic(width) &&
1577  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1578  return failure();
1579 
1580  return success();
1581 }
1582 
1583 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1584  MLIRContext *context, ::std::optional<Location> location,
1585  ConcatOp::Adaptor adaptor,
1586  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1587  // Infer all dimension sizes by reducing based on inputs.
1588  const Properties &prop = adaptor.getProperties();
1589  int32_t axis = prop.axis.getValue().getSExtValue();
1590  llvm::SmallVector<int64_t> outputShape;
1591  bool hasRankedInput = false;
1592  for (auto operand : adaptor.getOperands()) {
1593  ShapeAdaptor operandShape(operand.getType());
1594  if (!operandShape.hasRank())
1595  continue;
1596 
1597  // Copy the Operand's rank.
1598  if (!hasRankedInput)
1599  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1600 
1601  // Copy the static size of each non-axis dim and check it is consistent
1602  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1603  if (i == axis || operandShape.isDynamicDim(i))
1604  continue;
1605  if (outputShape[i] == ShapedType::kDynamic)
1606  outputShape[i] = operandShape.getDimSize(i);
1607  if (outputShape[i] != operandShape.getDimSize(i))
1608  return emitOptionalError(location,
1609  "Cannot concat tensors with different sizes"
1610  " on the non-axis dimension ",
1611  i);
1612  }
1613 
1614  hasRankedInput = true;
1615  }
1616 
1617  if (adaptor.getInput1().empty())
1618  return failure();
1619 
1620  Type inputType =
1621  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1622  if (!hasRankedInput) {
1623  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1624  return success();
1625  }
1626 
1627  // Determine the dimension size along the concatenation axis.
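  // For example, concatenating shapes [2, 3] and [2, 5] along axis 1 yields an
  // axis dimension of 3 + 5 = 8 and an output shape of [2, 8].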
1628  int64_t concatDimSize = 0;
1629  for (auto operand : adaptor.getOperands()) {
1630  ShapeAdaptor operandShape(operand.getType());
1631 
1632  // We need to know the length of the concatenation axis of all inputs to
1633  // determine the dimension size of the output shape.
1634  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1635  concatDimSize = ShapedType::kDynamic;
1636  break;
1637  }
1638 
1639  concatDimSize += operandShape.getDimSize(axis);
1640  }
1641 
1642  outputShape[axis] = concatDimSize;
1643 
1644  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1645  return success();
1646 }
1647 
1648 LogicalResult tosa::ConcatOp::verify() {
1649  // Check that each input has the same element type as the output.
1650  auto outType = getOutput().getType();
1651  const Operation::operand_range inputList = getInput1();
1652 
1653  // Check there is at least one input
1654  if (inputList.empty())
1655  return emitOpError("expect at least one input");
1656 
1657  if (!llvm::all_of(inputList, [&](auto input) {
1658  return succeeded(verifySameElementTypes(
1659  *this, /* inType = */ input.getType(), outType));
1660  })) {
1661  return failure();
1662  }
1663 
1664  const int32_t axis = getAxis();
1665  ShapeAdaptor firstRankedInputShape = nullptr;
1666  for (const auto &input : inputList) {
1667  const Type inputType = input.getType();
1668  ShapeAdaptor currShape(inputType);
1669  if (currShape.hasRank()) {
1670  firstRankedInputShape = currShape;
1671  // Check axis is in expected range
1672  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1673  return emitOpError("expect axis to be within range 0 <= axis < "
1674  "rank(input1[firstRankedTensorIdx]), got ")
1675  << axis;
1676  break;
1677  }
1678  }
1679 
1680  const auto allOperandsHasRank = [](const Value input) {
1681  return ShapeAdaptor(input.getType()).hasRank();
1682  };
1683  if (llvm::all_of(inputList, allOperandsHasRank)) {
1684  const int64_t firstInputRank = firstRankedInputShape.getRank();
1685 
1686  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1687  const ShapeAdaptor inputShape(input.getType());
1688  const int64_t inputRank = inputShape.getRank();
1689  const size_t operandNum = index + 1;
1690 
1691  // Check that each operand has the same rank
1692  if (inputRank != firstInputRank)
1693  return emitOpError(
1694  "expect all operands to have the same rank, but got ")
1695  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1696  << operandNum;
1697 
1698  // Check non-axis dims match
1699  for (int i = 0; i < inputRank; i++) {
1700  const int64_t inputDim = inputShape.getDimSize(i);
1701  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1702  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1703  inputShape.isDynamicDim(i))
1704  continue;
1705  if (inputDim != firstInputDim)
1706  return emitOpError("expect all operand shapes to have the same sizes "
1707  "on non-axis dimensions, but got ")
1708  << inputDim << " vs " << firstInputDim << " at index " << i
1709  << " on operands 0 and " << operandNum;
1710  }
1711  }
1712 
1713  // ERROR_IF(axis_sum != shape[axis]);
1714  int64_t axisSum = 0;
1715  for (const auto &input : inputList) {
1716  const ShapeAdaptor inputShape(input.getType());
1717  if (inputShape.isDynamicDim(axis)) {
1718  // make axisSum negative to indicate invalid value
1719  axisSum = -1;
1720  break;
1721  }
1722  axisSum += inputShape.getDimSize(axis);
1723  }
1724  const ShapeAdaptor outputShape(outType);
1725  if (axisSum >= 0 && outputShape.hasRank() &&
1726  !outputShape.isDynamicDim(axis) &&
1727  axisSum != outputShape.getDimSize(axis))
1728  return emitOpError("requires sum of axis dimensions of input1 "
1729  "equal to output axis dimension, got ")
1730  << axisSum << " and " << outputShape.getDimSize(axis);
1731  }
1732 
1733  return success();
1734 }
1735 
1736 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1737  MLIRContext *context, ::std::optional<Location> location,
1738  ValueShapeRange operands, DictionaryAttr attributes,
1739  OpaqueProperties properties, RegionRange regions,
1740  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1741  auto elementType = IntegerType::get(context, /*width=*/1);
1742 
1743  llvm::SmallVector<int64_t> outShape;
1744  if (resolveBroadcastShape(operands, outShape).failed()) {
1745  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1746  return success();
1747  }
1748 
1749  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1750  return success();
1751 }
1752 
1753 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1754  if (l.size() != r.size() || l.size() != 1)
1755  return false;
1756  return succeeded(verifyCompatibleShape(l[0], r[0]));
1757 }
1758 
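// For matmul, an input a of shape [N, H, C] and an input b of shape [N, C, W]
// produce an output of shape [N, H, W]; each output dimension falls back to
// dynamic when the corresponding operand shape is unranked.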
1759 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1760  MLIRContext *context, ::std::optional<Location> location,
1761  MatMulOp::Adaptor adaptor,
1762  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1763  ShapeAdaptor lhsShape(adaptor.getA().getType());
1764  ShapeAdaptor rhsShape(adaptor.getB().getType());
1765 
1766  // All shapes are dynamic.
1767  SmallVector<int64_t> outShape;
1768  outShape.resize(3, ShapedType::kDynamic);
1769 
1770  if (lhsShape.hasRank()) {
1771  outShape[0] = lhsShape.getDimSize(0);
1772  outShape[1] = lhsShape.getDimSize(1);
1773  }
1774 
1775  if (rhsShape.hasRank()) {
1776  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1777  : outShape[0];
1778  outShape[2] = rhsShape.getDimSize(2);
1779  }
1780 
1781  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1782  return success();
1783 }
1784 
1785 LogicalResult MatMulOp::verify() {
1786  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1787  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1788 
1789  // Must be shaped tensor types
1790  if (!aType)
1791  return emitOpError("expect a shaped tensor for input a, got ")
1792  << getA().getType();
1793 
1794  if (!bType)
1795  return emitOpError("expect a shaped tensor for input b, got ")
1796  << getB().getType();
1797 
1798  auto aElementType = aType.getElementType();
1799  auto bElementType = bType.getElementType();
1800 
1801  auto aQuantizedEType =
1802  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1803  auto bQuantizedEType =
1804  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1805 
1806  if (aQuantizedEType || bQuantizedEType) {
1807  if (!aQuantizedEType || !bQuantizedEType) {
1808  return emitOpError("expect operands to be both quantized or both not "
1809  "quantized, got ")
1810  << aElementType << " and " << bElementType;
1811  }
1812  // both a and b have quantized element types
1813  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1814  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1815  if (aQuantWidth != bQuantWidth) {
1816  return emitOpError("expect quantized operands to have same widths, got ")
1817  << aQuantWidth << " and " << bQuantWidth;
1818  }
1819  }
1820 
1821  // check a_zp and b_zp
1822  auto aEType = getStorageElementTypeOrSelf(aType);
1823  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1824  if (aEType != aZpEType) {
1825  return emitOpError("expect input a and a_zp have the same "
1826  "element type, got ")
1827  << aEType << " and " << aZpEType;
1828  }
1829 
1830  auto bEType = getStorageElementTypeOrSelf(bType);
1831  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1832  if (bEType != bZpEType) {
1833  return emitOpError("expect input b and b_zp have the same "
1834  "element type, got ")
1835  << bEType << " and " << bZpEType;
1836  }
1837 
1838  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1839  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1840  return failure();
1841 
1842  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1843  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1844  return failure();
1845 
1846  return success();
1847 }
1848 
1849 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1850  MLIRContext *context, ::std::optional<Location> location,
1851  PadOp::Adaptor adaptor,
1852  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1853  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1854  auto paddingRank =
1855  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1856  SmallVector<int64_t> outputShape;
1857 
1858  // If the input rank is unknown, we can infer the output rank using the
1859  // padding shape's rank divided by 2.
1860  if (!inputShape.hasRank()) {
1861  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1862  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1863  return success();
1864  }
1865 
1866  SmallVector<int64_t> paddingValues;
1867  // If the padding values are not constant, all output dimensions are dynamic.
1868  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1869  paddingValues)) {
1870  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1871  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1872  return success();
1873  }
1874 
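  // Each static output dimension is input size + pad_front + pad_back; e.g. a
  // [4, 5] input with padding values [0, 1, 2, 3] produces a [5, 10] output.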
1875  outputShape.reserve(inputShape.getRank());
1876  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1877  if (inputShape.isDynamicDim(i)) {
1878  outputShape.push_back(ShapedType::kDynamic);
1879  continue;
1880  }
1881  auto padFront = paddingValues[i * 2];
1882  auto padBack = paddingValues[i * 2 + 1];
1883  if (padFront < 0 || padBack < 0) {
1884  // if either padding for dim i is -1, output dim is unknown
1885  outputShape.push_back(ShapedType::kDynamic);
1886  continue;
1887  }
1888 
1889  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1890  }
1891 
1892  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1893  return success();
1894 }
1895 
1896 LogicalResult tosa::PadOp::verify() {
1897  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1898  /* outType = */ getOutput().getType())
1899  .failed()) {
1900  return failure();
1901  }
1902 
1903  if (auto padConst = getPadConst()) {
1904  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1905  /* outType = */ getOutput().getType())
1906  .failed()) {
1907  return failure();
1908  }
1909  }
1910 
1911  RankedTensorType inputType =
1912  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1913  RankedTensorType outputType =
1914  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1915  if (!inputType || !outputType)
1916  return success();
1917 
1918  auto inputRank = inputType.getRank();
1919  auto outputRank = outputType.getRank();
1920  if (inputRank != outputRank)
1921  return emitOpError() << "expect same input and output tensor rank, but got "
1922  << "inputRank: " << inputRank
1923  << ", outputRank: " << outputRank;
1924 
1925  DenseIntElementsAttr paddingAttr;
1926  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1927  return failure();
1928  }
1929 
1930  auto paddingValues = paddingAttr.getValues<APInt>();
1931  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1932  return emitOpError() << "padding tensor must have " << inputRank
1933  << " * 2 = " << inputRank * 2 << " elements, but got "
1934  << paddingValues.size();
1935 
1936  auto inputShape = inputType.getShape();
1937  auto outputShape = outputType.getShape();
1938 
1939  for (int64_t i = 0; i < inputRank; ++i) {
1940  int64_t padStart = paddingValues[i * 2].getSExtValue();
1941  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1942 
1943  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1944  return emitOpError()
1945  << "invalid padding values at dimension " << i
1946  << ": values must be non-negative or -1 for dynamic padding, got ["
1947  << padStart << ", " << padEnd << "]";
1948  }
1949 
1950  // Skip shape verification for dynamic input/output
1951  if (inputShape[i] == ShapedType::kDynamic ||
1952  outputShape[i] == ShapedType::kDynamic)
1953  continue;
1954 
1955  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1956  return emitOpError() << "mismatch in output shape at dimension " << i
1957  << ": expected " << inputShape[i] << " + "
1958  << padStart << " + " << padEnd << " = "
1959  << (inputShape[i] + padStart + padEnd)
1960  << ", but got " << outputShape[i];
1961  }
1962  }
1963 
1964  return success();
1965 }
1966 
1967 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1968  MLIRContext *context, ::std::optional<Location> location,
1969  SliceOp::Adaptor adaptor,
1970  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1971 
1972  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1973  SmallVector<int64_t> start;
1974  SmallVector<int64_t> size;
1975 
1976  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1977  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1978  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1979  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1980  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1981  return success();
1982  }
1983 
1984  // if size[i] is -1, all remaining elements in dimension i are included
1985  // in the slice, similar to TF.
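  // For example, an input of shape [10, 8] with start [2, 0] and size [-1, 4]
  // produces an output of shape [8, 4].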
1986  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1987  // initialize outputShape to all unknown
1988  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1989  if (inputShape.hasRank()) {
1990  for (size_t i = 0; i < size.size(); i++) {
1991  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1992  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1993  start[i] < inputShape.getDimSize(i))) {
1994  // size[i] is not 0 and not < -1, and start[i] is in valid range
1995  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1996  // input shape has unknown dim[i] - only valid if size[i] > 0
1997  if (size[i] > 0) {
1998  outputShape[i] = size[i];
1999  }
2000  } else {
2001  // input shape has known dim[i]
2002  if (size[i] == -1) {
2003  outputShape[i] = inputShape.getDimSize(i) - start[i];
2004  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2005  // start[i] + size[i] is within bound of input shape's dim[i]
2006  outputShape[i] = size[i];
2007  }
2008  }
2009  }
2010  }
2011  } else {
2012  outputShape = convertToMlirShape(size);
2013  }
2014  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2015  return success();
2016 }
2017 
2018 LogicalResult tosa::SliceOp::verify() {
2019  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2020  /* outType = */ getOutput().getType())
2021  .failed())
2022  return failure();
2023 
2024  const ShapeAdaptor inputShape(getInput1().getType());
2025  if (inputShape.hasRank()) {
2026  const auto inputRank = inputShape.getRank();
2027  const ShapeAdaptor outputShape(getOutput().getType());
2028  if (outputShape.hasRank() && inputRank != outputShape.getRank())
2029  return emitOpError(
2030  "expect input1 and output to have the same ranks, got ")
2031  << inputRank << " and " << outputShape.getRank();
2032 
2033  const auto startShapeRank =
2034  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2035  if (inputRank != startShapeRank)
2036  return emitOpError("length of start is not equal to rank of input shape");
2037 
2038  const auto sizeShapeRank =
2039  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2040  if (inputRank != sizeShapeRank)
2041  return emitOpError("length of size is not equal to rank of input shape");
2042  }
2043 
2044  return success();
2045 }
2046 
2047 LogicalResult tosa::MulOp::inferReturnTypeComponents(
2048  MLIRContext *context, ::std::optional<Location> location,
2049  ValueShapeRange operands, DictionaryAttr attributes,
2050  OpaqueProperties properties, RegionRange regions,
2051  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2052  // mul op's output shape only depends on input1 and input2, not on shift
2053  ValueShapeRange twoInputs = operands.drop_back();
2054  llvm::SmallVector<int64_t> outShape;
2055  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2056  inferredReturnShapes.push_back(ShapedTypeComponents());
2057  } else {
2058  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2059  }
2060  return success();
2061 }
2062 
2063 LogicalResult tosa::MulOp::verify() {
2064  const Value output = getOutput();
2065  auto resElemType = getElementTypeOrSelf(output);
2066 
2067  // Verify if the element type among operands and result match tosa
2068  // specification.
2069  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2070  IntegerType lhsIntType =
2071  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2072  IntegerType rhsIntType =
2073  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2074  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2075  return emitOpError("requires the same element type for all operands");
2076 
2077  // Though the spec requires the element type of the result to be i32, the
2078  // dialect is more relaxed here to make it easier to interoperate with
2079  // other dialects.
2080  if (lhsIntType.getWidth() > resIntType.getWidth())
2081  return emitOpError("invalid data type size for operands or result");
2082 
2083  } else {
2084  // For other supported types, the spec requires the same element
2085  // type for all operands (excluding the `shift` operand) and results.
2086  for (int i = 0; i < 2; ++i) {
2087  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2088  return emitOpError(
2089  "requires the same element type for all operands and results");
2090  }
2091 
2092  // verify shift has value 0 for non-integer types
2093  ElementsAttr shift_elem;
2094  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2095  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2096  if (shift != 0) {
2097  return emitOpError() << "require shift to be 0 for float type";
2098  }
2099  }
2100  }
2101 
2102  // Verify the op has the same rank for all main operands (excluding extra
2103  // operands such as mul's shift, which is the only difference from the
2104  // built-in `SameOperandsAndResultRank` trait) and result types, if known.
2105  TypeRange operandTypes = getOperandTypes();
2106  ShapedType aType = cast<ShapedType>(operandTypes[0]);
2107  ShapedType bType = cast<ShapedType>(operandTypes[1]);
2108 
2109  const bool aHasRank = aType.hasRank();
2110  const bool bHasRank = bType.hasRank();
2111  if (aHasRank && bHasRank) {
2112  const int64_t aRank = aType.getRank();
2113  const int64_t bRank = bType.getRank();
2114  if (aRank != bRank)
2115  return emitOpError("a and b operands don't have matching ranks, got ")
2116  << aRank << " and " << bRank;
2117 
2118  // check for broadcast compatible shapes
2119  SmallVector<int64_t> resultShape;
2120  if (!mlir::OpTrait::util::getBroadcastedShape(
2121  aType.getShape(), bType.getShape(), resultShape))
2122  return emitOpError("a and b operands don't have broadcast-compatible "
2123  "shapes, got ")
2124  << aType << " and " << bType;
2125  }
2126 
2127  ShapedType resultType = cast<ShapedType>(output.getType());
2128  if (!resultType.hasRank())
2129  return success();
2130 
2131  const int64_t resultRank = resultType.getRank();
2132  if (aHasRank && resultRank != aType.getRank())
2133  return emitOpError("result type has different rank than a, got ")
2134  << resultRank << " vs " << aType.getRank();
2135  if (bHasRank && resultRank != bType.getRank())
2136  return emitOpError("result type has different rank than b, got ")
2137  << resultRank << " vs " << bType.getRank();
2138 
2139  return success();
2140 }
2141 
2142 LogicalResult tosa::TableOp::inferReturnTypeComponents(
2143  MLIRContext *context, ::std::optional<Location> location,
2144  TableOp::Adaptor adaptor,
2145  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2146  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2147 
2148  if (!inputShape.hasRank()) {
2149  inferredReturnShapes.push_back(ShapedTypeComponents());
2150  return success();
2151  }
2152 
2153  inferredReturnShapes.resize(1);
2154  inputShape.getDims(inferredReturnShapes[0]);
2155  return success();
2156 }
2157 
2158 LogicalResult tosa::TableOp::verify() {
2159  const TensorType inputType = getInput1().getType();
2160  const TensorType outputType = getOutput().getType();
2161 
2162  if (!inputType.hasRank() || !outputType.hasRank())
2163  return success();
2164 
2165  if (inputType.getRank() != outputType.getRank())
2166  return emitOpError()
2167  << "expected input tensor rank to equal result tensor rank";
2168 
2169  auto inputDims = inputType.getShape();
2170  auto outputDims = outputType.getShape();
2171  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2172  int64_t dim = it.index();
2173  auto [inputDim, outputDim] = it.value();
2174  if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2175  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2176  << " doesn't match dim(input, " << dim
2177  << ") = " << inputDim;
2178  }
2179  }
2180  return success();
2181 }
2182 
2183 LogicalResult
2184 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2185  // Multiples must be constants.
2186  DenseIntElementsAttr multiplesAttr;
2187  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2188  return failure();
2189  multiples = llvm::to_vector(
2190  llvm::map_range(multiplesAttr.getValues<APInt>(),
2191  [](const APInt &val) { return val.getSExtValue(); }));
2192  return success();
2193 }
2194 
2195 LogicalResult tosa::TileOp::inferReturnTypeComponents(
2196  MLIRContext *context, ::std::optional<Location> location,
2197  TileOp::Adaptor adaptor,
2198  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2199  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2200  SmallVector<int64_t> multiples;
2201  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2202  multiples)) {
2203  auto rank =
2204  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2205  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2206  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2207  return success();
2208  } else {
2209  multiples = convertToMlirShape(multiples);
2210  }
2211 
2212  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2213  SmallVector<int64_t> outputShape;
2214  if (!inputShape.hasRank()) {
2215  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2216  inferredReturnShapes.push_back(
2217  ShapedTypeComponents(outputShape, inputType));
2218  return success();
2219  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2220  return failure();
2221 
2222  // Any non-dynamic dimension can be multiplied to a known size.
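  // For example, an input of shape [2, 3] with multiples [3, 1] produces an
  // output of shape [6, 3].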
2223  outputShape.reserve(multiples.size());
2224  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2225  if (multiples[i] == ShapedType::kDynamic) {
2226  outputShape.push_back(ShapedType::kDynamic);
2227  } else {
2228  int64_t dim = inputShape.getDimSize(i);
2229  if (dim != ShapedType::kDynamic)
2230  dim *= multiples[i];
2231  outputShape.push_back(dim);
2232  }
2233  }
2234 
2235  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2236  return success();
2237 }
2238 
2239 LogicalResult tosa::TileOp::verify() {
2240  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2241  /* outType = */ getOutput().getType())
2242  .failed()) {
2243  return failure();
2244  }
2245  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2246  ShapedType outputType = llvm::cast<ShapedType>(getType());
2247 
2248  shapeType multiplesType =
2249  llvm::cast<tosa::shapeType>(getMultiples().getType());
2250 
2251  auto multiplesRank = multiplesType.getRank();
2252 
2253  if (inputType.hasRank()) {
2254  if (inputType.getRank() != multiplesRank)
2255  return emitOpError("expect 'multiples' to have rank ")
2256  << inputType.getRank() << " but got " << multiplesRank << ".";
2257  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2258  return emitOpError("expect same input and output tensor rank.");
2259  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2260  return emitOpError("expect 'multiples' array to have length ")
2261  << outputType.getRank() << " but got " << multiplesRank << ".";
2262 
2263  SmallVector<int64_t> multiples;
2264  if (getConstantMultiples(multiples).succeeded() &&
2265  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2266  return emitOpError(
2267  "expect element of 'multiples' to be positive integer or -1.");
2268 
2269  return success();
2270 }
2271 
2272 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2273  if (l.size() != r.size() || l.size() != 1)
2274  return false;
2275  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2276 }
2277 
2278 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2279  MLIRContext *context, ::std::optional<Location> location,
2280  ReshapeOp::Adaptor adaptor,
2281  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2282  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2283  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2284  llvm::SmallVector<int64_t> newShapeValue;
2285  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2286  newShapeValue)) {
2287  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2288  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2289  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2290  return success();
2291  } else {
2292  newShapeValue = convertToMlirShape(newShapeValue);
2293  }
2294 
2295  // We cannot infer from the total number of elements so we must take the
2296  // shape attribute as exact.
2297  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2298  inferredReturnShapes.push_back(
2299  ShapedTypeComponents(newShapeValue, inputType));
2300  return success();
2301  }
2302 
2303  // Determine the number of elements covered by the slice of all static
2304  // dimensions. This allows us to infer the length of the remaining dynamic
2305  // dimension.
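  // For example, a static 2x3x4 input (24 elements) reshaped to [4, -1]
  // resolves the dynamic dimension to 24 / 4 = 6.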
2306  int64_t numElements = inputShape.getNumElements();
2307  int64_t staticMul = 1;
2308  for (auto val : newShapeValue) {
2309  if (ShapedType::isStatic(val)) {
2310  staticMul *= val;
2311  }
2312  }
2313 
2314  // Determine the length of the dynamic dimension.
2315  for (auto &val : newShapeValue) {
2316  if (ShapedType::isDynamic(val))
2317  val = numElements / staticMul;
2318  }
2319 
2320  inferredReturnShapes.push_back(
2321  ShapedTypeComponents(newShapeValue, inputType));
2322  return success();
2323 }
2324 
2325 llvm::LogicalResult tosa::ReshapeOp::verify() {
2326  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2327  /* outType = */ getOutput().getType())
2328  .failed()) {
2329  return failure();
2330  }
2331  TensorType inputType = getInput1().getType();
2332 
2333  SmallVector<int64_t> shapeValues;
2334  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2335  // skip following checks if shape is not constant
2336  return mlir::success();
2337  }
2338 
2339  int missingDims = llvm::count(shapeValues, -1);
2340  if (missingDims > 1)
2341  return emitOpError() << "expected at most one target dimension to be -1";
2342 
2343  const auto outputType = dyn_cast<RankedTensorType>(getType());
2344  if (!outputType)
2345  return success();
2346 
2347  if ((int64_t)shapeValues.size() != outputType.getRank())
2348  return emitOpError() << "new shape does not match result rank";
2349 
2350  for (auto [newShapeDim, outputShapeDim] :
2351  zip(shapeValues, outputType.getShape())) {
2352  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2353  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2354  return emitOpError() << "new shape is inconsistent with result shape";
2355 
2356  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2357  return emitOpError() << "new shape has invalid tensor dimension size "
2358  << newShapeDim;
2359  }
2360 
2361  if (inputType.hasStaticShape()) {
2362  int64_t inputElementsNum = inputType.getNumElements();
2363  if (outputType.hasStaticShape()) {
2364  int64_t outputElementsNum = outputType.getNumElements();
2365  if (inputElementsNum != outputElementsNum) {
2366  return emitOpError() << "cannot reshape " << inputElementsNum
2367  << " elements into " << outputElementsNum;
2368  }
2369  }
2370 
2371  int64_t newShapeElementsNum =
2372  llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2373  return (dim > 0) ? acc * dim : acc;
2374  });
2375  bool isStaticNewShape =
2376  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2377  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2378  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2379  return emitOpError() << "cannot reshape " << inputElementsNum
2380  << " elements into " << newShapeElementsNum;
2381  }
2382  }
2383 
2384  return mlir::success();
2385 }
2386 
2387 // Return failure if val is not a constant.
2388 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
2389 // otherwise set zp to val's constant value.
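// For example, an i8 zero-point constant holding 0x80 yields -128 when sign
// extended (signExtend = true) and 128 otherwise.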
2390 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2391  ElementsAttr zpAttr;
2392  if (!matchPattern(val, m_Constant(&zpAttr))) {
2393  return failure();
2394  }
2395 
2396  Type zpElemType = zpAttr.getElementType();
2397 
2398  if (llvm::isa<FloatType>(zpElemType)) {
2399  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2400  return 0;
2401  }
2402  // return non-zero value to trigger error check
2403  return -1;
2404  }
2405 
2406  if (llvm::isa<IntegerType>(zpElemType)) {
2407  if (signExtend)
2408  return zpAttr.getValues<APInt>()[0].getSExtValue();
2409  else
2410  return zpAttr.getValues<APInt>()[0].getZExtValue();
2411  }
2412 
2413  // return non-zero value to trigger error check
2414  return -1;
2415 }
2416 
2417 template <typename T>
2418 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2419  const std::string &operand) {
2420  Type zpElemType = getElementTypeOrSelf(val);
2421 
2422  if (!zpElemType.isInteger(8) && zp != 0) {
2423  // convert operand to lower case for error message
2424  std::string lower = operand;
2425  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2426  return op.emitOpError()
2427  << lower << " zero point must be zero for non-int8 integer types";
2428  }
2429 
2430  return success();
2431 }
2432 
2433 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2434  const int64_t &zp,
2435  const std::string &operand) {
2436  bool isInputZp = (operand == "Input");
2437 
2438  bool tensorUnsigned =
2439  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2440  StringRef tensorName = isInputZp ? "input" : "output";
2441 
2442  Type zpElemType = getElementTypeOrSelf(zpVal);
2443 
2444  if (zp != 0) {
2445  if (!zpElemType.isInteger(8) &&
2446  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2447  return op.emitOpError()
2448  << "expect " << tensorName << "_zp of 0, got " << zp;
2449  }
2450  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2451  return op.emitOpError() << "expect " << tensorName
2452  << "_zp of 0 or 32768 for unsigned int16 "
2453  << tensorName << ", got " << zp;
2454  }
2455  }
2456 
2457  return success();
2458 }
2459 
2460 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2461  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2462  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2463  } \
2464  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2465  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2466  }
2467 
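// Each ZERO_POINT_HELPER instantiation below defines the pair of accessors
// used by the verifiers above; for example, ZERO_POINT_HELPER(MatMulOp, A, true)
// expands to tosa::MatMulOp::getAZeroPoint() and
// tosa::MatMulOp::verifyAZeroPoint(int64_t).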
2468 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2469 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2470 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2471 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2472 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2473 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2474 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2475 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2476 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2477 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2478 ZERO_POINT_HELPER(MatMulOp, A, true)
2479 ZERO_POINT_HELPER(MatMulOp, B, true)
2480 ZERO_POINT_HELPER(NegateOp, Input1, true)
2481 ZERO_POINT_HELPER(NegateOp, Output, true)
2482 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2483 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2484 #undef ZERO_POINT_HELPER
2485 
2486 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2487  MLIRContext *context, ::std::optional<Location> location,
2488  TransposeOp::Adaptor adaptor,
2489  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2490  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2491 
2492  // If the input rank is unknown, the output rank is
2493  // unknown as well.
2494  if (!inputShape.hasRank()) {
2495  inferredReturnShapes.push_back(ShapedTypeComponents());
2496  return success();
2497  }
2498 
2499  const auto inputRank = inputShape.getRank();
2500 
2501  // This would imply the number of permutations does not match the rank of
2502  // the input which is illegal.
2503  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2504  return failure();
2505  }
2506 
2507  SmallVector<int64_t> outputShape;
2508  // Rank-0 means no permutations matter.
2509  if (inputRank == 0) {
2510  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2511  return success();
2512  }
2513 
2514  // Check whether the input dimensions are all the same.
2515  bool allTheSame = true;
2516  for (int i = 1, s = inputRank; i < s; i++) {
2517  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2518  allTheSame = false;
2519  break;
2520  }
2521  }
2522 
2523  // If all of the input dimensions are the same we don't care about the
2524  // permutation.
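  // For example, a [3, 3, 3] input has shape [3, 3, 3] under any permutation.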
2525  if (allTheSame) {
2526  outputShape.resize(inputRank, inputShape.getDimSize(0));
2527  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2528  return success();
2529  }
2530 
2531  outputShape.resize(inputRank, ShapedType::kDynamic);
2532 
2533  // Constant permutation values must be within the input rank.
2534  if (llvm::any_of(adaptor.getPerms(),
2535  [inputRank](const auto i) { return i >= inputRank; }))
2536  return failure();
2537 
2538  outputShape.reserve(inputRank);
2539  for (int i = 0, s = inputRank; i < s; i++) {
2540  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2541  }
2542 
2543  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2544  return success();
2545 }
2546 
2547 LogicalResult tosa::TransposeOp::verify() {
2548  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2549  /* outType = */ getOutput().getType())
2550  .failed()) {
2551  return failure();
2552  }
2553 
2554  const ShapeAdaptor inputShape(getInput1().getType());
2555  const ShapeAdaptor outputShape(getOutput().getType());
2556 
2557  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2558 
2559  if (inputShape.hasRank() &&
2560  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2561  return emitOpError() << "expected perms attribute to have size "
2562  << inputShape.getRank()
2563  << " (input rank) but got size "
2564  << constantPerms.size();
2565 
2566  if (inputShape.hasRank() && outputShape.hasRank() &&
2567  inputShape.getRank() != outputShape.getRank())
2568  return emitOpError()
2569  << "expected input tensor rank to equal result tensor rank";
2570 
2571  if (outputShape.hasRank() &&
2572  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2573  return emitOpError() << "expected perms attribute to have size "
2574  << outputShape.getRank()
2575  << " (output rank) but got size "
2576  << constantPerms.size();
2577 
2578  if (!llvm::all_of(constantPerms,
2579  [&constantPerms](int32_t s) {
2580  return s >= 0 &&
2581  static_cast<size_t>(s) < constantPerms.size();
2582  }) ||
2583  !isPermutationVector(llvm::to_vector(llvm::map_range(
2584  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2585  return emitOpError() << "expected valid permutation indices";
2586 
2587  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2588  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2589  inputShape.getNumElements() != outputShape.getNumElements())
2590  return emitOpError() << "expected input1 and output to have same numbers "
2591  "of elements, got "
2592  << inputShape.getNumElements() << " and "
2593  << outputShape.getNumElements();
2594 
2595  // Verify that the types of the input and output tensors are properly
2596  // permuted.
2597  if (inputShape.hasRank() && outputShape.hasRank()) {
2598  for (auto i = 0; i < outputShape.getRank(); i++) {
2599  if (inputShape.isDynamicDim(constantPerms[i]) ||
2600  outputShape.isDynamicDim(i))
2601  continue;
2602 
2603  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2604  return emitOpError()
2605  << "expected output tensor dim " << i << " to match "
2606  << "input dim " << constantPerms[i] << " with value of "
2607  << inputShape.getDimSize(constantPerms[i]);
2608  }
2609  }
2610 
2611  return success();
2612 }
2613 
2614 LogicalResult TransposeOp::reifyResultShapes(
2615  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2616 
2617  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2618 
2619  Value input = getInput1();
2620  auto inputType = cast<TensorType>(input.getType());
2621 
2622  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2623  for (auto dim : transposePerms) {
2624  int32_t dimInInput = transposePerms[dim];
2625  if (inputType.isDynamicDim(dimInInput))
2626  returnedDims[dim] =
2627  tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2628  .getResult();
2629  else
2630  returnedDims[dim] =
2631  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2632  }
2633 
2634  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2635  return success();
2636 }
2637 
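// For gather, values of shape [N, K, C] and indices of shape [N, W] produce an
// output of shape [N, W, C]; N and C are taken from values, W from indices,
// with dynamic fallbacks when either operand is unranked.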
2638 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2639  MLIRContext *context, ::std::optional<Location> location,
2640  GatherOp::Adaptor adaptor,
2641  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2642  llvm::SmallVector<int64_t> outputShape;
2643  outputShape.resize(3, ShapedType::kDynamic);
2644 
2645  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2646  if (valuesShape.hasRank()) {
2647  outputShape[0] = valuesShape.getDimSize(0);
2648  outputShape[2] = valuesShape.getDimSize(2);
2649  }
2650 
2651  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2652  if (indicesShape.hasRank()) {
2653  if (outputShape[0] == ShapedType::kDynamic)
2654  outputShape[0] = indicesShape.getDimSize(0);
2655  if (outputShape[1] == ShapedType::kDynamic)
2656  outputShape[1] = indicesShape.getDimSize(1);
2657  }
2658 
2659  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2660  return success();
2661 }
2662 
2663 LogicalResult tosa::GatherOp::verify() {
2664  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2665  /* outType = */ getOutput().getType())
2666  .failed()) {
2667  return failure();
2668  }
2669 
2670  const ShapeAdaptor valuesShape(getValues().getType());
2671  const ShapeAdaptor indicesShape(getIndices().getType());
2672  const ShapeAdaptor outputShape(getOutput().getType());
2673 
2674  int64_t N = ShapedType::kDynamic;
2675  int64_t W = ShapedType::kDynamic;
2676  int64_t C = ShapedType::kDynamic;
2677 
2678  if (valuesShape.hasRank()) {
2679  N = valuesShape.getDimSize(0);
2680  C = valuesShape.getDimSize(2);
2681  }
2682  if (indicesShape.hasRank()) {
2683  const int64_t indicesN = indicesShape.getDimSize(0);
2684  W = indicesShape.getDimSize(1);
2685  if (N == ShapedType::kDynamic)
2686  N = indicesN;
2687  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2688  return emitOpError() << "requires indices dimension 0 to have size " << N
2689  << ", got " << indicesN;
2690  }
2691  if (outputShape.hasRank()) {
2692  const int64_t outputN = outputShape.getDimSize(0);
2693  const int64_t outputW = outputShape.getDimSize(1);
2694  const int64_t outputC = outputShape.getDimSize(2);
2695  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2696  N != outputN)
2697  return emitOpError() << "requires output dimension 0 to have size " << N
2698  << ", got " << outputN;
2699 
2700  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2701  W != outputW)
2702  return emitOpError() << "requires output dimension 1 to have size " << W
2703  << ", got " << outputW;
2704  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2705  C != outputC)
2706  return emitOpError() << "requires output dimension 2 to have size " << C
2707  << ", got " << outputC;
2708  }
2709  return success();
2710 }
2711 
2712 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2713  MLIRContext *context, ::std::optional<Location> location,
2714  ResizeOp::Adaptor adaptor,
2715  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2716  llvm::SmallVector<int64_t, 4> outputShape;
2717  outputShape.resize(4, ShapedType::kDynamic);
2718 
2719  ShapeAdaptor inputShape(adaptor.getInput().getType());
2720  if (!inputShape.hasRank())
2721  return failure();
2722 
2723  outputShape[0] = inputShape.getDimSize(0);
2724  outputShape[3] = inputShape.getDimSize(3);
2725  int64_t inputHeight = inputShape.getDimSize(1);
2726  int64_t inputWidth = inputShape.getDimSize(2);
2727 
2728  if ((inputHeight == ShapedType::kDynamic) ||
2729  (inputWidth == ShapedType::kDynamic))
2730  return failure();
2731 
2732  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2733  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2734  scaleInt) ||
2735  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2736  offsetInt) ||
2737  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2738  borderInt)) {
2739  return failure();
2740  }
2741 
2742  // Compute the output shape based on attributes: scale, offset, and border.
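  // output_height = ((input_height - 1) * scale_y_n - offset_y + border_y) /
  // scale_y_d + 1, and similarly for the width. For example, input_height = 4
  // with scale_y = 2/1 and zero offset/border gives (3 * 2) / 1 + 1 = 7.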
2743  const int64_t outputHeight =
2744  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2745  scaleInt[1]) +
2746  1;
2747 
2748  const int64_t outputWidth =
2749  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2750  scaleInt[3]) +
2751  1;
2752 
2753  if (outputHeight < 0 || outputWidth < 0) {
2754  return emitOptionalError(
2755  location,
2756  "calculated output height and width must be non-negative, "
2757  "got height = ",
2758  outputHeight, ", width = ", outputWidth);
2759  }
2760 
2761  outputShape[1] = outputHeight;
2762  outputShape[2] = outputWidth;
2763  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2764  return success();
2765 }
2766 
2767 LogicalResult tosa::ResizeOp::verify() {
2768  const Value input = getInput();
2769  const Value output = getOutput();
2770  const RankedTensorType inputType =
2771  llvm::dyn_cast<RankedTensorType>(input.getType());
2772  const RankedTensorType outputType =
2773  llvm::dyn_cast<RankedTensorType>(output.getType());
2774 
2775  SmallVector<int64_t> scaleValues;
2776  SmallVector<int64_t> offsetValues;
2777  SmallVector<int64_t> borderValues;
2778  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2779  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2780  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2781  // Skip following checks if shape is not constant
2782  return success();
2783  }
2784 
2785  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2786  return emitOpError("expect all scale values to be > 0, got ")
2787  << scaleValues;
2788 
2789  const int64_t scaleYN = scaleValues[0];
2790  const int64_t scaleYD = scaleValues[1];
2791  const int64_t scaleXN = scaleValues[2];
2792  const int64_t scaleXD = scaleValues[3];
2793 
2794  const int64_t offsetY = offsetValues[0];
2795  const int64_t offsetX = offsetValues[1];
2796 
2797  const int64_t borderY = borderValues[0];
2798  const int64_t borderX = borderValues[1];
2799 
2800  if (!inputType)
2801  return success();
2802  if (!outputType)
2803  return success();
2804 
2805  const int64_t oh = outputType.getDimSize(1);
2806  const int64_t ow = outputType.getDimSize(2);
2807  const int64_t ih = inputType.getDimSize(1);
2808  const int64_t iw = inputType.getDimSize(2);
2809 
2810  // Skip this check when the input height is 1 (i.e. it could be broadcast),
2811  // since Linalg, a consumer of TOSA, expects broadcasting support
2812  // in resize to be available. Taking the cautious approach for now,
2813  // we can consider removing support for broadcasting later.
2814  if (ih != ShapedType::kDynamic && ih != 1) {
2815  const std::optional<int64_t> calculatedOutHeightMinusOne =
2816  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2817  if (!calculatedOutHeightMinusOne.has_value())
2818  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2819  "border_y ")
2820  << "to be wholly divisible by scale_y_d, got ((" << ih
2821  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2822  << ") / " << scaleYD;
2823  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2824  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2825  return emitOpError("calculated output height did not match expected: ")
2826  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2827  }
2828 
2829  // Skip this check when the input width is 1 (i.e. it could be broadcast),
2830  // since Linalg, a consumer of TOSA, expects broadcasting support
2831  // in resize to be available. Taking the cautious approach for now,
2832  // we can consider removing support for broadcasting later.
2833  if (iw != ShapedType::kDynamic && iw != 1) {
2834  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2835  const std::optional<int64_t> calculatedOutWidthMinusOne =
2836  idivCheck(scaledInWidth, scaleXD);
2837  if (!calculatedOutWidthMinusOne.has_value())
2838  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2839  "border_x ")
2840  << "to be wholly divisible by scale_x_d, got ((" << iw
2841  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2842  << ") / " << scaleXD;
2843  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2844  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2845  return emitOpError("calculated output width did not match expected: ")
2846  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2847  }
2848 
2849  return success();
2850 }
2851 
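// For scatter, values_in of shape [N, K, C], indices of shape [N, W], and
// input of shape [N, W, C] produce values_out of shape [N, K, C]; missing
// dimensions are filled in from whichever operand provides them.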
2852 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2853  MLIRContext *context, ::std::optional<Location> location,
2854  ScatterOp::Adaptor adaptor,
2855  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2856  llvm::SmallVector<int64_t> outputShape;
2857  outputShape.resize(3, ShapedType::kDynamic);
2858 
2859  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2860  if (valuesInShape.hasRank()) {
2861  outputShape[0] = valuesInShape.getDimSize(0);
2862  outputShape[1] = valuesInShape.getDimSize(1);
2863  outputShape[2] = valuesInShape.getDimSize(2);
2864  }
2865 
2866  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2867  if (indicesShape.hasRank()) {
2868  if (outputShape[0] == ShapedType::kDynamic)
2869  outputShape[0] = indicesShape.getDimSize(0);
2870  }
2871 
2872  ShapeAdaptor inputShape(adaptor.getInput().getType());
2873  if (inputShape.hasRank()) {
2874  if (outputShape[0] == ShapedType::kDynamic)
2875  outputShape[0] = inputShape.getDimSize(0);
2876  if (outputShape[2] == ShapedType::kDynamic)
2877  outputShape[2] = inputShape.getDimSize(2);
2878  }
2879 
2880  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2881  return success();
2882 }
2883 
2884 LogicalResult tosa::ScatterOp::verify() {
2885  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2886  /* outType = */ getValuesOut().getType())
2887  .failed() ||
2888  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2889  /* outType = */ getValuesOut().getType())
2890  .failed()) {
2891  return failure();
2892  }
2893 
2894  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2895  const ShapeAdaptor indicesShape(getIndices().getType());
2896  const ShapeAdaptor inputShape(getInput().getType());
2897  const ShapeAdaptor outputShape(getValuesOut().getType());
2898 
2899  int64_t N = ShapedType::kDynamic;
2900  int64_t K = ShapedType::kDynamic;
2901  int64_t W = ShapedType::kDynamic;
2902  int64_t C = ShapedType::kDynamic;
2903  if (valuesInShape.hasRank()) {
2904  N = valuesInShape.getDimSize(0);
2905  K = valuesInShape.getDimSize(1);
2906  C = valuesInShape.getDimSize(2);
2907  }
2908  if (indicesShape.hasRank()) {
2909  const int64_t indicesN = indicesShape.getDimSize(0);
2910  W = indicesShape.getDimSize(1);
2911  if (N == ShapedType::kDynamic)
2912  N = indicesN;
2913  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2914  return emitOpError() << "requires indices dimension 0 to have size " << N
2915  << ", got " << indicesN;
2916  }
2917  if (inputShape.hasRank()) {
2918  const int64_t inputN = inputShape.getDimSize(0);
2919  const int64_t inputW = inputShape.getDimSize(1);
2920  const int64_t inputC = inputShape.getDimSize(2);
2921  if (N == ShapedType::kDynamic)
2922  N = inputN;
2923  else if (inputN != ShapedType::kDynamic && N != inputN)
2924  return emitOpError() << "requires input dimension 0 to have size " << N
2925  << ", got " << inputN;
2926  if (W == ShapedType::kDynamic)
2927  W = inputW;
2928  else if (inputW != ShapedType::kDynamic && W != inputW)
2929  return emitOpError() << "requires input dimension 1 to have size " << W
2930  << ", got " << inputW;
2931 
2932  if (C == ShapedType::kDynamic)
2933  C = inputC;
2934  else if (inputC != ShapedType::kDynamic && C != inputC)
2935  return emitOpError() << "requires input dimension 2 to have size " << C
2936  << ", got " << inputC;
2937  }
2938  if (outputShape.hasRank()) {
2939  const int64_t outputN = outputShape.getDimSize(0);
2940  const int64_t outputK = outputShape.getDimSize(1);
2941  const int64_t outputC = outputShape.getDimSize(2);
2942  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2943  N != outputN)
2944  return emitOpError() << "requires values_out dimension 0 to have size "
2945  << N << ", got " << outputN;
2946  if (K == ShapedType::kDynamic)
2947  K = outputK;
2948  else if (outputK != ShapedType::kDynamic && K != outputK)
2949  return emitOpError() << "requires values_out dimension 1 to have size "
2950  << K << ", got " << outputK;
2951  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2952  C != outputC)
2953  return emitOpError() << "requires values_out dimension 2 to have size "
2954  << C << ", got " << outputC;
2955  }
2956  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2957  return emitOpError() << "requires dimensions K >= W, got K=" << K
2958  << " and W=" << W;
2959 
2960  return success();
2961 }
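
As a worked illustration of the dimension bookkeeping above (not part of the dialect source), the following standalone C++ sketch spells out the N/K/W/C roles and the K >= W requirement for fully static shapes; every name in it is invented for this example.

#include <cassert>
#include <cstdint>

// values_in  : [N, K, C]   indices    : [N, W]
// input      : [N, W, C]   values_out : [N, K, C], with K >= W required.
static bool scatterStaticShapesOk(int64_t N, int64_t K, int64_t W, int64_t C) {
  (void)N;
  (void)C; // N and C only need to agree across operands; no extra constraint.
  return K >= W;
}

int main() {
  // values_in 2x8x5, indices 2x4, input 2x4x5, values_out 2x8x5:
  // all shared dimensions agree and K=8 >= W=4, so the verifier accepts it.
  assert(scatterStaticShapesOk(/*N=*/2, /*K=*/8, /*W=*/4, /*C=*/5));
  return 0;
}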
2962 
2963 static LogicalResult ReduceInferReturnTypes(
2964  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2965  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2966  int64_t axisVal = axis.getValue().getSExtValue();
2967  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2968  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2969  return success();
2970  }
2971 
2972  SmallVector<int64_t> outputShape;
2973  operandShape.getDims(outputShape);
2974  outputShape[axisVal] = 1;
2975  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2976  return success();
2977 }
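
To make the inference rule above concrete, here is a hedged, self-contained sketch (names invented for illustration): the output keeps the input shape and collapses the reduced axis to 1.

#include <cstdint>
#include <iostream>
#include <vector>

// The reduced axis becomes 1; all other dimensions are preserved.
static std::vector<int64_t> reducedShape(std::vector<int64_t> shape,
                                         int64_t axis) {
  shape[axis] = 1;
  return shape;
}

int main() {
  for (int64_t d : reducedShape({2, 3, 4}, /*axis=*/1))
    std::cout << d << ' '; // prints: 2 1 4
  std::cout << '\n';
  return 0;
}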
2978 
2979 #define COMPATIBLE_RETURN_TYPES(OP) \
2980  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2981  if (l.size() != r.size() || l.size() != 1) \
2982  return false; \
2983  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2984  return false; \
2985  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2986  }
2987 
2988 #define REDUCE_SHAPE_INFER(OP) \
2989  LogicalResult OP::inferReturnTypeComponents( \
2990  MLIRContext *context, ::std::optional<Location> location, \
2991  OP::Adaptor adaptor, \
2992  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2993  Type inputType = \
2994  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2995  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2996  const Properties &prop = adaptor.getProperties(); \
2997  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2998  inferredReturnShapes); \
2999  } \
3000  COMPATIBLE_RETURN_TYPES(OP)
3001 
3002 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3003 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3004 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3005 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3006 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3007 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3008 #undef REDUCE_SHAPE_INFER
3009 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3010 #undef COMPATIBLE_RETURN_TYPES
3011 
3012 template <typename T>
3013 static LogicalResult verifyReduceOp(T op) {
3014  // All TOSA reduce Ops have input, output and axis.
3015  TensorType inputType = op.getInput().getType();
3016  TensorType outputType = op.getOutput().getType();
3017  int32_t reduceAxis = op.getAxis();
3018 
3019  if (reduceAxis < 0) {
3020  op.emitOpError("reduce axis must not be negative");
3021  return failure();
3022  }
3023  if (inputType.hasRank()) {
3024  int64_t inputRank = inputType.getRank();
3025  // We allow for a special case where the input/output shape has rank 0 and
3026  // axis is also 0.
3027  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3028  op.emitOpError("expect input tensor rank (")
3029  << inputRank << ") to be larger than reduce axis (" << reduceAxis
3030  << ")";
3031  return failure();
3032  }
3033  }
3034  if (outputType.hasRank()) {
3035  int64_t outputRank = outputType.getRank();
3036  if (inputType.hasRank() && outputRank != inputType.getRank()) {
3037  op.emitOpError(
3038  "expect output tensor rank to be equal to input tensor rank");
3039  return failure();
3040  }
3041  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3042  op.emitOpError("expect output tensor rank (")
3043  << outputRank << ") to be larger than reduce axis (" << reduceAxis
3044  << ")";
3045  return failure();
3046  }
3047  // We can only verify the reduced dimension size to be 1 if this is not
3048  // the special case of output rank == 0.
3049  if (outputRank != 0) {
3050  auto outputShape = outputType.getShape();
3051  if (!outputType.isDynamicDim(reduceAxis) &&
3052  outputShape[reduceAxis] != 1) {
3053  op.emitOpError("expect reduced dimension size to be 1, got ")
3054  << outputShape[reduceAxis];
3055  return failure();
3056  }
3057  }
3058  }
3059  return success();
3060 }
3061 
3062 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3063 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3064 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3065 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3066 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3067 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3068 
3069 static LogicalResult NAryInferReturnTypes(
3070  const ValueShapeRange &operands,
3071  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3072  llvm::SmallVector<int64_t> outShape;
3073  if (resolveBroadcastShape(operands, outShape).failed()) {
3074  inferredReturnShapes.push_back(ShapedTypeComponents());
3075  } else {
3076  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3077  }
3078  return success();
3079 }
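
resolveBroadcastShape is defined earlier in this file. As a rough guide only, the sketch below implements the conventional right-aligned broadcast rule (dimensions must match or be 1) that TOSA element-wise ops are expected to follow; treat it as an assumption, not a copy of the helper.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Assumed broadcast rule: align shapes on the right; a dimension of 1
// stretches to match, anything else must agree exactly.
static std::optional<std::vector<int64_t>>
broadcastShapes(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  std::vector<int64_t> out(std::max(a.size(), b.size()), 1);
  for (size_t i = 0; i < out.size(); ++i) {
    int64_t da = i < out.size() - a.size() ? 1 : a[i - (out.size() - a.size())];
    int64_t db = i < out.size() - b.size() ? 1 : b[i - (out.size() - b.size())];
    if (da != db && da != 1 && db != 1)
      return std::nullopt; // incompatible, mirrors the failure path above
    out[i] = std::max(da, db);
  }
  return out;
}

int main() {
  const std::vector<int64_t> expected{2, 3, 4};
  auto s = broadcastShapes({2, 1, 4}, {3, 1});
  assert(s && *s == expected); // {2,1,4} and {3,1} broadcast to {2,3,4}
  return 0;
}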
3080 
3081 #define NARY_SHAPE_INFER(OP) \
3082  LogicalResult OP::inferReturnTypeComponents( \
3083  MLIRContext *context, ::std::optional<Location> location, \
3084  ValueShapeRange operands, DictionaryAttr attributes, \
3085  OpaqueProperties properties, RegionRange regions, \
3086  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3087  return NAryInferReturnTypes(operands, inferredReturnShapes); \
3088  }
3089 
3090 NARY_SHAPE_INFER(tosa::AbsOp)
3091 NARY_SHAPE_INFER(tosa::AddOp)
3092 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3093 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3094 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3095 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3096 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3097 NARY_SHAPE_INFER(tosa::CastOp)
3098 NARY_SHAPE_INFER(tosa::CeilOp)
3099 NARY_SHAPE_INFER(tosa::ClampOp)
3100 NARY_SHAPE_INFER(tosa::ClzOp)
3101 NARY_SHAPE_INFER(tosa::CosOp)
3102 NARY_SHAPE_INFER(tosa::ExpOp)
3103 NARY_SHAPE_INFER(tosa::FloorOp)
3104 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3105 NARY_SHAPE_INFER(tosa::GreaterOp)
3106 NARY_SHAPE_INFER(tosa::IdentityOp)
3107 NARY_SHAPE_INFER(tosa::IntDivOp)
3108 NARY_SHAPE_INFER(tosa::LogOp)
3109 NARY_SHAPE_INFER(tosa::LogicalAndOp)
3110 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3111 NARY_SHAPE_INFER(tosa::LogicalNotOp)
3112 NARY_SHAPE_INFER(tosa::LogicalOrOp)
3113 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3114 NARY_SHAPE_INFER(tosa::LogicalXorOp)
3115 NARY_SHAPE_INFER(tosa::MaximumOp)
3116 NARY_SHAPE_INFER(tosa::MinimumOp)
3117 NARY_SHAPE_INFER(tosa::PowOp)
3118 NARY_SHAPE_INFER(tosa::ReciprocalOp)
3119 NARY_SHAPE_INFER(tosa::ReverseOp)
3120 NARY_SHAPE_INFER(tosa::RsqrtOp)
3121 NARY_SHAPE_INFER(tosa::SinOp)
3122 NARY_SHAPE_INFER(tosa::SelectOp)
3123 NARY_SHAPE_INFER(tosa::SubOp)
3124 NARY_SHAPE_INFER(tosa::TanhOp)
3125 NARY_SHAPE_INFER(tosa::ErfOp)
3126 NARY_SHAPE_INFER(tosa::SigmoidOp)
3127 #undef NARY_SHAPE_INFER
3128 
3129 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3130  MLIRContext *context, ::std::optional<Location> location,
3131  NegateOp::Adaptor adaptor,
3132  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3133  ShapeAdaptor inputShape(adaptor.getInput1().getType());
3134  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3135  return success();
3136 }
3137 
3138 LogicalResult tosa::NegateOp::verify() {
3139  // Verify same element type
3140  const Type input1Type = getInput1().getType();
3141  const Type outputType = getOutput().getType();
3142  if (verifySameElementTypes(*this, input1Type, outputType).failed())
3143  return failure();
3144 
3145  // Verify same shape
3146  const SmallVector<Type, 2> types = {input1Type, outputType};
3147  if (failed(verifyCompatibleShapes(types)))
3148  return emitOpError() << "requires the same shape for input1 and output";
3149 
3150  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3151  const Type input1ZpEType =
3152  getStorageElementTypeOrSelf(getInput1Zp().getType());
3153  if (input1EType != input1ZpEType) {
3154  return emitOpError("expect both input1 and its zero point are the same "
3155  "element type, got ")
3156  << input1EType << " and " << input1ZpEType;
3157  }
3158  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3159  const Type outputZpEType =
3160  getStorageElementTypeOrSelf(getOutputZp().getType());
3161  if (outputEType != outputZpEType) {
3162  return emitOpError("expect both output and its zero point are the same "
3163  "element type, got ")
3164  << outputEType << " and " << outputZpEType;
3165  }
3166 
3167  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3168  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3169  return failure();
3170 
3171  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3172  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3173  return failure();
3174 
3175  return success();
3176 }
3177 
3178 static LogicalResult poolingInferReturnTypes(
3179  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3180  ArrayRef<int64_t> pad,
3181  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3182  llvm::SmallVector<int64_t> outputShape;
3183  outputShape.resize(4, ShapedType::kDynamic);
3184 
3185  // If the input is unranked, all we can infer is the output rank (4).
3186  if (!inputShape) {
3187  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3188  return success();
3189  }
3190 
3191  // Batch and number of channels pass through the pooling layer unchanged.
3192  outputShape[0] = inputShape.getDimSize(0);
3193  outputShape[3] = inputShape.getDimSize(3);
3194 
3195  int64_t height = inputShape.getDimSize(1);
3196  int64_t width = inputShape.getDimSize(2);
3197 
3198  if (ShapedType::isStatic(height)) {
3199  int64_t padded = height + pad[0] + pad[1] - kernel[0];
3200  outputShape[1] = padded / stride[0] + 1;
3201  }
3202 
3203  if (ShapedType::isStatic(width)) {
3204  int64_t padded = width + pad[2] + pad[3] - kernel[1];
3205  outputShape[2] = padded / stride[1] + 1;
3206  }
3207 
3208  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3209  return success();
3210 }
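
A small standalone sketch of the same per-dimension arithmetic, with one worked number; the helper name is invented for illustration.

#include <cassert>
#include <cstdint>

// Same arithmetic as above for one spatial dimension:
// out = (in + pad_before + pad_after - kernel) / stride + 1.
static int64_t pooledSize(int64_t in, int64_t padBefore, int64_t padAfter,
                          int64_t kernel, int64_t stride) {
  return (in + padBefore + padAfter - kernel) / stride + 1;
}

int main() {
  // A 32-wide input with kernel 3, stride 2 and padding 1 on each side
  // yields a 16-wide output: (32 + 1 + 1 - 3) / 2 + 1 = 16.
  assert(pooledSize(32, 1, 1, 3, 2) == 16);
  return 0;
}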
3211 
3212 LogicalResult Conv2DOp::inferReturnTypeComponents(
3213  MLIRContext *context, ::std::optional<Location> location,
3214  Conv2DOp::Adaptor adaptor,
3215  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3216  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3217 
3218  int64_t inputWidth = ShapedType::kDynamic;
3219  int64_t inputHeight = ShapedType::kDynamic;
3220  int64_t weightWidth = ShapedType::kDynamic;
3221  int64_t weightHeight = ShapedType::kDynamic;
3222 
3223  // Input shape describes input width/height and batch.
3224 
3225  ShapeAdaptor inputShape(adaptor.getInput().getType());
3226  if (inputShape.hasRank()) {
3227  outputShape[0] = inputShape.getDimSize(0);
3228  inputHeight = inputShape.getDimSize(1);
3229  inputWidth = inputShape.getDimSize(2);
3230  }
3231 
3232  // The weight shape describes the filter width/height and the output channels.
3233  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3234  if (weightShape.hasRank()) {
3235  outputShape[3] = weightShape.getDimSize(0);
3236  weightHeight = weightShape.getDimSize(1);
3237  weightWidth = weightShape.getDimSize(2);
3238  }
3239 
3240  // Bias shape can describe the output channels.
3241  ShapeAdaptor biasShape(adaptor.getBias().getType());
3242  if (biasShape.hasRank()) {
3243  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3244  ? biasShape.getDimSize(0)
3245  : outputShape[3];
3246  }
3247 
3248  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3249  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3250  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3251 
3252  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3253  int64_t inputSize = inputHeight + padding[0] + padding[1];
3254  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3255  int64_t unstridedResult = inputSize - filterSize + 1;
3256  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3257  }
3258 
3259  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3260  int64_t inputSize = inputWidth + padding[2] + padding[3];
3261  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3262  int64_t unstridedResult = inputSize - filterSize + 1;
3263  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3264  }
3265 
3266  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3267  return success();
3268 }
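
The spatial computation above reduces to the familiar convolution output-size formula; the hedged sketch below (invented names) checks two common configurations.

#include <cassert>
#include <cstdint>

// One spatial dimension of the convolution arithmetic above:
// effective filter = (kernel - 1) * dilation + 1,
// out = (in + pad_before + pad_after - filter) / stride + 1.
static int64_t convOutSize(int64_t in, int64_t padBefore, int64_t padAfter,
                           int64_t kernel, int64_t dilation, int64_t stride) {
  int64_t filter = (kernel - 1) * dilation + 1;
  return (in + padBefore + padAfter - filter) / stride + 1;
}

int main() {
  // "Same" 3x3 convolution: 32 input, pad 1+1, dilation 1, stride 1 -> 32.
  assert(convOutSize(32, 1, 1, 3, 1, 1) == 32);
  // Strided case: 224 input, pad 3+3, kernel 7, dilation 1, stride 2 -> 112.
  assert(convOutSize(224, 3, 3, 7, 1, 2) == 112);
  return 0;
}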
3269 
3270 LogicalResult Conv2DOp::verify() {
3271  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3272  verifyConvOpErrorIf(*this).failed())
3273  return failure();
3274  return success();
3275 }
3276 
3277 LogicalResult Conv3DOp::inferReturnTypeComponents(
3278  MLIRContext *context, ::std::optional<Location> location,
3279  Conv3DOp::Adaptor adaptor,
3280  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3281  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3282 
3283  int64_t inputWidth = ShapedType::kDynamic;
3284  int64_t inputHeight = ShapedType::kDynamic;
3285  int64_t inputDepth = ShapedType::kDynamic;
3286 
3287  int64_t weightWidth = ShapedType::kDynamic;
3288  int64_t weightHeight = ShapedType::kDynamic;
3289  int64_t weightDepth = ShapedType::kDynamic;
3290 
3291  // Input shape describes input depth/height/width and batch.
3292  ShapeAdaptor inputShape(adaptor.getInput().getType());
3293  if (inputShape.hasRank()) {
3294  outputShape[0] = inputShape.getDimSize(0);
3295  inputDepth = inputShape.getDimSize(1);
3296  inputHeight = inputShape.getDimSize(2);
3297  inputWidth = inputShape.getDimSize(3);
3298  }
3299 
3300  // The weight shape describes the filter depth/height/width and the output channels.
3301  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3302  if (weightShape.hasRank()) {
3303  outputShape[4] = weightShape.getDimSize(0);
3304  weightDepth = weightShape.getDimSize(1);
3305  weightHeight = weightShape.getDimSize(2);
3306  weightWidth = weightShape.getDimSize(3);
3307  }
3308 
3309  // Bias shape can describe the output channels.
3310  ShapeAdaptor biasShape(adaptor.getBias().getType());
3311  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3312  outputShape[4] = biasShape.getDimSize(0);
3313  }
3314 
3315  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3316  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3317  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3318 
3319  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3320  int32_t inputSize = inputDepth + pad[0] + pad[1];
3321  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3322  int32_t unstridedResult = inputSize - filterSize + 1;
3323  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3324  }
3325 
3326  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3327  int32_t inputSize = inputHeight + pad[2] + pad[3];
3328  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3329  int32_t unstridedResult = inputSize - filterSize + 1;
3330  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3331  }
3332 
3333  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3334  int32_t inputSize = inputWidth + pad[4] + pad[5];
3335  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3336  int32_t unstridedResult = inputSize - filterSize + 1;
3337  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3338  }
3339 
3340  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3341  return success();
3342 }
3343 
3344 LogicalResult Conv3DOp::verify() {
3345  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3346  verifyConvOpErrorIf(*this).failed())
3347  return failure();
3348  return success();
3349 }
3350 
3351 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3352  MLIRContext *context, ::std::optional<Location> location,
3353  AvgPool2dOp::Adaptor adaptor,
3354  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3355  ShapeAdaptor inputShape(adaptor.getInput().getType());
3356  const Properties &prop = adaptor.getProperties();
3357  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3358  inferredReturnShapes);
3359 }
3360 
3361 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3362  MLIRContext *context, ::std::optional<Location> location,
3363  MaxPool2dOp::Adaptor adaptor,
3364  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3365  ShapeAdaptor inputShape(adaptor.getInput().getType());
3366  const Properties &prop = adaptor.getProperties();
3367  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3368  inferredReturnShapes);
3369 }
3370 
3371 LogicalResult MaxPool2dOp::verify() {
3372  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3373  /* outType = */ getOutput().getType())))
3374  return failure();
3375 
3376  if (failed(verifyPoolingOp(*this)))
3377  return failure();
3378 
3379  return success();
3380 }
3381 
3382 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3383  MLIRContext *context, ::std::optional<Location> location,
3384  DepthwiseConv2DOp::Adaptor adaptor,
3385  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3386  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3387 
3388  int64_t inputWidth = ShapedType::kDynamic;
3389  int64_t inputHeight = ShapedType::kDynamic;
3390  int64_t inputChannels = ShapedType::kDynamic;
3391 
3392  int64_t weightWidth = ShapedType::kDynamic;
3393  int64_t weightHeight = ShapedType::kDynamic;
3394  int64_t depthChannels = ShapedType::kDynamic;
3395 
3396  // Input shape describes input width/height and batch.
3397  ShapeAdaptor inputShape(adaptor.getInput().getType());
3398  if (inputShape.hasRank()) {
3399  outputShape[0] = inputShape.getDimSize(0);
3400  inputHeight = inputShape.getDimSize(1);
3401  inputWidth = inputShape.getDimSize(2);
3402  inputChannels = inputShape.getDimSize(3);
3403  }
3404 
3405  // The weight shape describes the filter height/width, the input channels and the depth multiplier.
3406  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3407  if (weightShape.hasRank()) {
3408  weightHeight = weightShape.getDimSize(0);
3409  weightWidth = weightShape.getDimSize(1);
3410  inputChannels = ShapedType::isDynamic(inputChannels)
3411  ? weightShape.getDimSize(2)
3412  : inputChannels;
3413  depthChannels = weightShape.getDimSize(3);
3414  }
3415 
3416  // If both inputChannels and depthChannels are available we can determine
3417  // the output channels.
3418  if (ShapedType::isStatic(inputChannels) &&
3419  ShapedType::isStatic(depthChannels)) {
3420  outputShape[3] = inputChannels * depthChannels;
3421  }
3422 
3423  // Bias shape can describe the output channels.
3424  ShapeAdaptor biasShape(adaptor.getBias().getType());
3425  if (biasShape.hasRank()) {
3426  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3427  ? biasShape.getDimSize(0)
3428  : outputShape[3];
3429  }
3430 
3431  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3432  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3433  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3434 
3435  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3436  int64_t inputSize = inputHeight + padding[0] + padding[1];
3437  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3438  int64_t unstridedResult = inputSize - filterSize + 1;
3439  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3440  }
3441 
3442  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3443  int64_t inputSize = inputWidth + padding[2] + padding[3];
3444  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3445  int64_t unstridedResult = inputSize - filterSize + 1;
3446  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3447  }
3448 
3449  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3450  return success();
3451 }
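
The channel computation above boils down to multiplying the input channel count by the depth multiplier stored in the weight's last dimension; a tiny hedged sketch with invented names:

#include <cassert>
#include <cstdint>

// Output channels = input channels * depth multiplier (weight's last dim).
static int64_t depthwiseOutChannels(int64_t inputChannels, int64_t multiplier) {
  return inputChannels * multiplier;
}

int main() {
  // 8 input channels with a depth multiplier of 2 give 16 output channels.
  assert(depthwiseOutChannels(8, 2) == 16);
  return 0;
}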
3452 
3453 LogicalResult DepthwiseConv2DOp::verify() {
3454  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3455  verifyConvOpErrorIf(*this).failed())
3456  return failure();
3457  return success();
3458 }
3459 
3460 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3461  MLIRContext *context, ::std::optional<Location> location,
3462  TransposeConv2DOp::Adaptor adaptor,
3463  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3464  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3465 
3466  int64_t inputWidth = ShapedType::kDynamic;
3467  int64_t inputHeight = ShapedType::kDynamic;
3468  int64_t weightWidth = ShapedType::kDynamic;
3469  int64_t weightHeight = ShapedType::kDynamic;
3470 
3471  // Input shape describes input width/height and batch.
3472  ShapeAdaptor inputShape(adaptor.getInput().getType());
3473  if (inputShape.hasRank()) {
3474  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3475  ? inputShape.getDimSize(0)
3476  : outputShape[0];
3477  inputHeight = inputShape.getDimSize(1);
3478  inputWidth = inputShape.getDimSize(2);
3479  }
3480 
3481  // The weight shape describes the filter width/height and the output channels.
3482  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3483  if (weightShape.hasRank()) {
3484  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3485  ? weightShape.getDimSize(0)
3486  : outputShape[3];
3487  weightHeight = weightShape.getDimSize(1);
3488  weightWidth = weightShape.getDimSize(2);
3489  }
3490 
3491  // Bias shape can describe the output channels.
3492  ShapeAdaptor biasShape(adaptor.getBias().getType());
3493  if (biasShape.hasRank()) {
3494  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3495  ? biasShape.getDimSize(0)
3496  : outputShape[3];
3497  }
3498 
3499  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3500  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3501 
3502  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3503  int64_t calculateSize =
3504  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3505  outputShape[1] =
3506  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3507  }
3508 
3509  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3510  int64_t calculateSize =
3511  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3512  outputShape[2] =
3513  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3514  }
3515 
3516  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3517  return success();
3518 }
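
The transpose-convolution size computation above is the inverse of the convolution formula; a hedged standalone sketch (invented names) with one worked value:

#include <cassert>
#include <cstdint>

// One spatial dimension of the arithmetic above:
// out = (in - 1) * stride + out_pad_before + out_pad_after + kernel.
static int64_t transposeConvOutSize(int64_t in, int64_t stride,
                                    int64_t padBefore, int64_t padAfter,
                                    int64_t kernel) {
  return (in - 1) * stride + padBefore + padAfter + kernel;
}

int main() {
  // Upsampling a 16-wide feature map with stride 2 and a 3-wide kernel,
  // no extra output padding: (16 - 1) * 2 + 0 + 0 + 3 = 33.
  assert(transposeConvOutSize(16, 2, 0, 0, 3) == 33);
  return 0;
}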
3519 
3520 LogicalResult TransposeConv2DOp::verify() {
3521  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3522  return failure();
3523 
3524  const llvm::ArrayRef<int64_t> strides = getStride();
3525  const int64_t strideY = strides[0];
3526  const int64_t strideX = strides[1];
3527 
3528  if (strideY < 1 || strideX < 1)
3529  return emitOpError("expect all stride values to be >= 1, got [")
3530  << strides << "]";
3531 
3532  const auto checkPadAgainstKernelDim =
3533  [this](int64_t pad_value, int64_t kernel_dim_size,
3534  llvm::StringRef pad_name,
3535  llvm::StringRef kernel_dim_name) -> LogicalResult {
3536  if (pad_value <= -kernel_dim_size)
3537  return emitOpError("expected ")
3538  << pad_name << " > -" << kernel_dim_name
3539  << ", but got: " << pad_name << "=" << pad_value << " and "
3540  << kernel_dim_name << "=" << kernel_dim_size;
3541  return success();
3542  };
3543 
3544  const llvm::ArrayRef<int64_t> padding = getOutPad();
3545  const int64_t outPadTop = padding[0];
3546  const int64_t outPadBottom = padding[1];
3547  const int64_t outPadLeft = padding[2];
3548  const int64_t outPadRight = padding[3];
3549 
3550  const auto weightType =
3551  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3552 
3553  if (weightType) {
3554  const int64_t kernelHeight = weightType.getDimSize(1);
3555  if (ShapedType::isStatic(kernelHeight)) {
3556  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3557  "out_pad_top", "KH")))
3558  return failure();
3559 
3560  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3561  "out_pad_bottom", "KH")))
3562  return failure();
3563  }
3564 
3565  const int64_t kernelWidth = weightType.getDimSize(2);
3566  if (ShapedType::isStatic(kernelWidth)) {
3567  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3568  "out_pad_left", "KW")))
3569  return failure();
3570 
3571  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3572  "out_pad_right", "KW")))
3573  return failure();
3574  }
3575  }
3576 
3577  // Rest of the checks depend on the output type being a RankedTensorType
3578  const auto outputType =
3579  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3580  if (!outputType)
3581  return success();
3582 
3583  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3584  if (inputType && weightType) {
3585  const int64_t inputHeight = inputType.getDimSize(1);
3586  const int64_t kernelHeight = weightType.getDimSize(1);
3587  const int64_t outputHeight = outputType.getDimSize(1);
3588 
3589  if (ShapedType::isStatic(inputHeight) &&
3590  ShapedType::isStatic(outputHeight)) {
3591  if (outputHeight !=
3592  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3593  return emitOpError(
3594  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3595  "+ out_pad_top + out_pad_bottom + KH, but got ")
3596  << outputHeight << " != (" << inputHeight << " - 1) * "
3597  << strideY << " + " << outPadTop << " + " << outPadBottom
3598  << " + " << kernelHeight;
3599  }
3600 
3601  const int64_t inputWidth = inputType.getDimSize(2);
3602  const int64_t kernelWidth = weightType.getDimSize(2);
3603  const int64_t outputWidth = outputType.getDimSize(2);
3604 
3605  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3606  if (outputWidth !=
3607  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3608  return emitOpError(
3609  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3610  "+ out_pad_left + out_pad_right + KW, but got ")
3611  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3612  << " + " << outPadLeft << " + " << outPadRight << " + "
3613  << kernelWidth;
3614  }
3615  }
3616 
3617  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3618 
3619  if (!biasType)
3620  return success();
3621 
3622  const int64_t biasChannels = biasType.getDimSize(0);
3623 
3624  // Skip further checks if bias is dynamic
3625  if (biasChannels == ShapedType::kDynamic)
3626  return success();
3627 
3628  const int64_t outputChannels = outputType.getDimSize(3);
3629  if (!ShapedType::isDynamic(outputChannels) &&
3630  biasChannels != outputChannels && biasChannels != 1)
3631  return emitOpError(
3632  "bias channels expected to be equal to output channels (")
3633  << outputChannels << ") or 1, got " << biasChannels;
3634 
3635  return success();
3636 }
3637 
3638 LogicalResult RescaleOp::verify() {
3639  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3640  if (!inputType) {
3641  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3642  return failure();
3643  }
3644 
3645  auto inputElementType =
3646  getStorageElementTypeOrSelf(inputType.getElementType());
3647  if (!mlir::isa<IntegerType>(inputElementType)) {
3648  emitOpError("expect input to have integer element type, got ")
3649  << inputElementType;
3650  return failure();
3651  }
3652 
3653  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3654  if (!outputType) {
3655  emitOpError("expect shaped tensor for output, got ")
3656  << getOutput().getType();
3657  return failure();
3658  }
3659 
3660  auto outputElementType =
3661  getStorageElementTypeOrSelf(outputType.getElementType());
3662  if (!mlir::isa<IntegerType>(outputElementType)) {
3663  emitOpError("expect output to have integer element type, got ")
3664  << outputElementType;
3665  return failure();
3666  }
3667 
3668  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3669  .failed())
3670  return failure();
3671 
3672  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3673  .failed())
3674  return failure();
3675 
3676  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3677  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3678  return failure();
3679 
3680  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3681  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3682  return failure();
3683 
3684  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3685  if (!multiplierType) {
3686  emitOpError("expect shaped tensor for multiplier, got ")
3687  << getMultiplier().getType();
3688  return failure();
3689  }
3690 
3691  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3692  if (!shiftType) {
3693  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3694  return failure();
3695  }
3696 
3697  // multiplier element type must be i32 for scale32 = true
3698  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3699  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3700  << multiplierType.getElementType();
3701  return failure();
3702  }
3703 
3704  // multiplier element type must be i16 for scale32 = false
3705  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3706  emitOpError(
3707  "expect i16 element type for multiplier for scale32=false, got ")
3708  << multiplierType.getElementType();
3709  return failure();
3710  }
3711 
3712  if (!inputType.hasRank())
3713  return success();
3714 
3715  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3716  // if per_channel = false, otherwise numChannels is the size of the last
3717  // dimension of the input shape.
3718  int64_t numChannels = 1;
3719  if (getPerChannel()) {
3720  if (inputType.getRank() < 1) {
3721  emitOpError("requires input to be at least rank 1 when per_channel is "
3722  "true, but got rank ")
3723  << inputType.getRank();
3724  return failure();
3725  }
3726  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3727  }
3728 
3729  if (!multiplierType.hasRank())
3730  return success();
3731 
3732  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3733  // multiplier input has rank 1 by dialect definition
3734  if (multiplierShape[0] != ShapedType::kDynamic &&
3735  multiplierShape[0] != numChannels) {
3736  emitOpError("expect shape of { ")
3737  << numChannels << " } for multiplier input, got { "
3738  << multiplierShape[0] << " }";
3739  return failure();
3740  }
3741 
3742  if (!shiftType.hasRank())
3743  return success();
3744 
3745  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3746  // shift input has rank 1 by dialect definition
3747  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3748  emitOpError("expect shape of { ")
3749  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3750  return failure();
3751  }
3752 
3753  return success();
3754 }
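
The per_channel handling above implies a simple length rule for the multiplier and shift operands; the following hedged sketch (invented names) restates it for static shapes:

#include <cassert>
#include <cstdint>
#include <vector>

// multiplier and shift are rank-1 tensors of size 1 when per_channel is
// false, or of size equal to the input's last dimension when it is true.
static int64_t expectedScaleLength(const std::vector<int64_t> &inputShape,
                                   bool perChannel) {
  return perChannel ? inputShape.back() : 1;
}

int main() {
  assert(expectedScaleLength({1, 16, 16, 64}, /*perChannel=*/false) == 1);
  assert(expectedScaleLength({1, 16, 16, 64}, /*perChannel=*/true) == 64);
  return 0;
}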
3755 
3756 LogicalResult RescaleOp::inferReturnTypeComponents(
3757  MLIRContext *context, ::std::optional<Location> location,
3758  RescaleOp::Adaptor adaptor,
3759  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3760  ShapeAdaptor inputShape(adaptor.getInput().getType());
3761  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3762  return success();
3763 }
3764 
3765 LogicalResult IfOp::inferReturnTypeComponents(
3766  MLIRContext *context, ::std::optional<Location> location,
3767  IfOp::Adaptor adaptor,
3768  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3769  llvm::SmallVector<tosa::YieldOp> yieldOps;
3770  for (Region *region : adaptor.getRegions()) {
3771  for (auto &block : *region)
3772  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3773  yieldOps.push_back(returnOp);
3774  }
3775 
3776  if (yieldOps.empty())
3777  return failure();
3778 
3779  // Get the initial type information for the yield op.
3780  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3781  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3782  for (auto operand : yieldOps.front().getOperands()) {
3783  resultKnowledge.push_back(
3784  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3785  }
3786 
3787  for (auto yieldOp : yieldOps) {
3788  if (resultKnowledge.size() != yieldOp.getNumOperands())
3789  return failure();
3790 
3791  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3792  int32_t index = it.index();
3793  auto meet = ValueKnowledge::meet(
3794  resultKnowledge[index],
3795  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3796  if (!meet)
3797  continue;
3798  resultKnowledge[index] = meet;
3799  }
3800  }
3801 
3802  for (const ValueKnowledge &result : resultKnowledge) {
3803  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3804  }
3805 
3806  return success();
3807 }
3808 
3809 LogicalResult WhileOp::inferReturnTypeComponents(
3810  MLIRContext *context, ::std::optional<Location> location,
3811  WhileOp::Adaptor adaptor,
3812  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3813  llvm::SmallVector<tosa::YieldOp> yieldOps;
3814  for (auto &block : adaptor.getBodyGraph())
3815  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3816  yieldOps.push_back(returnOp);
3817 
3818  // TOSA's while must have a tosa.yield as its terminator. If none is found,
3819  // this tosa.while is invalid.
3820  if (yieldOps.empty())
3821  return failure();
3822 
3823  // Get the initial type information from the operand types.
3824  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3825  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3826  for (auto operand : yieldOps.front().getOperands()) {
3827  resultKnowledge.push_back(
3828  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3829  }
3830 
3831  for (auto yieldOp : yieldOps) {
3832  if (resultKnowledge.size() != yieldOp.getNumOperands())
3833  return failure();
3834 
3835  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3836  int32_t index = it.index();
3837  if (auto meet = ValueKnowledge::meet(
3838  resultKnowledge[index],
3839  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3840  resultKnowledge[index] = meet;
3841  }
3842  }
3843  }
3844 
3845  for (const ValueKnowledge &result : resultKnowledge) {
3846  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3847  }
3848 
3849  return success();
3850 }
3851 
3852 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3853  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3854  return llvm::to_vector<4>(vt.getShape());
3855  return std::nullopt;
3856 }
3857 
3858 static void printInitializationList(OpAsmPrinter &parser,
3859  Block::BlockArgListType blocksArgs,
3860  ValueRange initializers,
3861  StringRef prefix = "") {
3862  assert(blocksArgs.size() == initializers.size() &&
3863  "expected same length of arguments and initializers");
3864  if (initializers.empty())
3865  return;
3866 
3867  parser << prefix << '(';
3868  llvm::interleaveComma(
3869  llvm::zip(blocksArgs, initializers), parser,
3870  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3871  parser << ")";
3872 }
3873 
3874 // The parse and print methods of IfOp follow the implementation in the SCF dialect.
3875 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3876  // Create the regions for 'then'.
3877  result.regions.reserve(2);
3878  Region *thenRegion = result.addRegion();
3879  Region *elseRegion = result.addRegion();
3880 
3882  OpAsmParser::UnresolvedOperand cond;
3883  if (parser.parseOperand(cond))
3884  return failure();
3885 
3887  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3888  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3889  // Parse the optional block arguments
3890  OptionalParseResult listResult =
3891  parser.parseOptionalAssignmentList(regionArgs, operands);
3892  if (listResult.has_value() && failed(listResult.value()))
3893  return failure();
3894 
3895  // Parse a colon.
3896  if (failed(parser.parseColon()))
3897  return parser.emitError(parser.getCurrentLocation(),
3898  "expected type for condition operand");
3899 
3900  // Parse the type of the condition operand
3901  Type condType;
3902  if (failed(parser.parseType(condType)))
3903  return parser.emitError(parser.getCurrentLocation(),
3904  "expected type for condition operand");
3905 
3906  // Resolve operand with provided type
3907  if (failed(parser.resolveOperand(cond, condType, result.operands)))
3908  return failure();
3909 
3910  // Parse optional block arg types
3911  if (listResult.has_value()) {
3912  FunctionType functionType;
3913 
3914  if (failed(parser.parseType(functionType)))
3915  return parser.emitError(parser.getCurrentLocation())
3916  << "expected list of types for block arguments "
3917  << "followed by arrow type and list of return types";
3918 
3919  result.addTypes(functionType.getResults());
3920 
3921  if (functionType.getNumInputs() != operands.size()) {
3922  return parser.emitError(parser.getCurrentLocation())
3923  << "expected as many input types as operands "
3924  << "(expected " << operands.size() << " got "
3925  << functionType.getNumInputs() << ")";
3926  }
3927 
3928  // Resolve input operands.
3929  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3930  parser.getCurrentLocation(),
3931  result.operands)))
3932  return failure();
3933  } else {
3934  // Parse optional results type list.
3935  if (parser.parseOptionalArrowTypeList(result.types))
3936  return failure();
3937  }
3938 
3939  // Parse the 'then' region.
3940  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3941  return failure();
3942 
3943  // If we find an 'else' keyword then parse the 'else' region.
3944  if (!parser.parseOptionalKeyword("else")) {
3945  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3946  return failure();
3947  }
3948 
3949  // Parse the optional attribute list.
3950  if (parser.parseOptionalAttrDict(result.attributes))
3951  return failure();
3952  return success();
3953 }
3954 
3955 void IfOp::print(OpAsmPrinter &p) {
3956  p << " " << getCondition();
3957 
3958  printInitializationList(p, getThenGraph().front().getArguments(),
3959  getInputList(), " ");
3960  p << " : ";
3961  p << getCondition().getType();
3962 
3963  if (!getInputList().empty()) {
3964  p << " (";
3965  llvm::interleaveComma(getInputList().getTypes(), p);
3966  p << ")";
3967  }
3968  p.printArrowTypeList(getResultTypes());
3969  p << " ";
3970 
3971  p.printRegion(getThenGraph());
3972 
3973  // Print the 'else' region if it exists and has a block.
3974  auto &elseRegion = getElseGraph();
3975  if (!elseRegion.empty()) {
3976  p << " else ";
3977  p.printRegion(elseRegion);
3978  }
3979 
3980  p.printOptionalAttrDict((*this)->getAttrs());
3981 }
3982 
3983 LogicalResult IfOp::verify() {
3984  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3985  "'then_graph' arguments", getInputList(),
3986  "'input_list'")
3987  .failed())
3988  return failure();
3989 
3990  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3991  "'else_graph' arguments", getInputList(),
3992  "'input_list'")
3993  .failed())
3994  return failure();
3995 
3996  // If the terminator is missing, MLIR's region verifier will report it for us.
3997  if (getThenGraph().front().mightHaveTerminator()) {
3998  auto thenYield =
3999  dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4000  if (thenYield && errorIfTypeOrShapeMismatch(
4001  *this, thenYield.getInputs(), "'then_graph' results",
4002  getOutputList(), "'output_list'")
4003  .failed())
4004  return failure();
4005  }
4006 
4007  // If the terminator is missing, MLIR's region verifier will report it for us.
4008  if (getElseGraph().front().mightHaveTerminator()) {
4009  auto elseYield =
4010  dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4011  if (elseYield && errorIfTypeOrShapeMismatch(
4012  *this, elseYield.getInputs(), "'else_graph' results",
4013  getOutputList(), "'output_list'")
4014  .failed())
4015  return failure();
4016  }
4017 
4018  auto condType = getCondition().getType();
4019  if (errorIfShapeNotSizeOne(*this, condType).failed())
4020  return emitOpError() << "'condition' must be a size 1 tensor, got "
4021  << condType;
4022 
4023  return success();
4024 }
4025 
4026 LogicalResult WhileOp::verify() {
4027  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4028  getOutputList(), "'output_list'")
4029  .failed())
4030  return failure();
4031 
4032  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4033  "'cond_graph' arguments", getInputList(),
4034  "'input_list'")
4035  .failed())
4036  return failure();
4037 
4038  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4039  "'body_graph' arguments", getInputList(),
4040  "'input_list'")
4041  .failed())
4042  return failure();
4043 
4044  if (getBodyGraph().front().mightHaveTerminator()) {
4045  auto bodyYield =
4046  dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4047  if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4048  "'body_graph' results",
4049  getInputList(), "'input_list'")
4050  .failed())
4051  return failure();
4052  }
4053 
4054  // Condition block output must be a single element tensor with a single bool
4055  // value.
4056  if (!getCondGraph().front().mightHaveTerminator())
4057  return success();
4058 
4059  auto condYield =
4060  dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4061  if (!condYield)
4062  return success();
4063 
4064  if (condYield.getInputs().size() != 1)
4065  return emitOpError() << "require 'cond_graph' only have one result";
4066 
4067  auto condOutType = condYield.getInputs()[0].getType();
4068  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4069  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4070  << condOutType;
4071 
4072  if (!getElementTypeOrSelf(condOutType).isInteger(1))
4073  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4074  << condOutType;
4075 
4076  return success();
4077 }
4078 
4079 LogicalResult ReverseOp::verify() {
4080  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4081  /* outType = */ getOutput().getType())
4082  .failed())
4083  return failure();
4084  TensorType inputType = getInput1().getType();
4085  TensorType outputType = getOutput().getType();
4086  int32_t reverseAxis = getAxis();
4087 
4088  if (reverseAxis < 0)
4089  return emitOpError("expected non-negative reverse axis");
4090  if (inputType.hasRank()) {
4091  int64_t inputRank = inputType.getRank();
4092  // We allow for a special case where the input/output shape has rank 0 and
4093  // axis is also 0.
4094  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4095  return emitOpError("expect input tensor rank (")
4096  << inputRank << ") to be larger than reverse axis (" << reverseAxis
4097  << ")";
4098  }
4099  if (outputType.hasRank()) {
4100  int64_t outputRank = outputType.getRank();
4101  if (inputType.hasRank() && outputRank != inputType.getRank())
4102  return emitOpError(
4103  "expect output tensor rank to be equal to input tensor rank");
4104  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4105  return emitOpError("expect output tensor rank (")
4106  << outputRank << ") to be larger than reverse axis ("
4107  << reverseAxis << ")";
4108  }
4109  return success();
4110 }
4111 
4112 LogicalResult tosa::SelectOp::verify() {
4113  // verify input2 and input3 have same element type as output
4114  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4115  /* outType = */ getOutput().getType())
4116  .failed() ||
4117  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4118  /* outType = */ getOutput().getType())
4119  .failed()) {
4120  return failure();
4121  }
4122  // verify input1 has element type of bool
4123  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4124  if (!predicateType) {
4125  return emitOpError("expect shaped tensor for input1, got ")
4126  << getInput1().getType();
4127  }
4128  auto predicateElementType = predicateType.getElementType();
4129  if (!predicateElementType.isInteger(1)) {
4130  return emitOpError("expect element type of bool for input1, got ")
4131  << predicateElementType;
4132  }
4133 
4134  return success();
4135 }
4136 
4137 LogicalResult tosa::VariableReadOp::verify() {
4138  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4139  .failed())
4140  return failure();
4141 
4142  return success();
4143 }
4144 
4145 LogicalResult tosa::VariableWriteOp::verify() {
4146  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4147  .failed())
4148  return failure();
4149 
4150  return success();
4151 }
4152 
4153 // The parse and print methods of WhileOp follow the implementation in the SCF dialect.
4154 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4155  SmallVector<OpAsmParser::Argument, 4> regionArgs;
4156  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4157  Region *cond = result.addRegion();
4158  Region *body = result.addRegion();
4159 
4160  OptionalParseResult listResult =
4161  parser.parseOptionalAssignmentList(regionArgs, operands);
4162  if (listResult.has_value() && failed(listResult.value()))
4163  return failure();
4164 
4165  FunctionType functionType;
4166  SMLoc typeLoc = parser.getCurrentLocation();
4167  if (failed(parser.parseColonType(functionType)))
4168  return failure();
4169 
4170  result.addTypes(functionType.getResults());
4171 
4172  if (functionType.getNumInputs() != operands.size()) {
4173  return parser.emitError(typeLoc)
4174  << "expected as many input types as operands "
4175  << "(expected " << operands.size() << " got "
4176  << functionType.getNumInputs() << ")";
4177  }
4178 
4179  // Resolve input operands.
4180  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4181  parser.getCurrentLocation(),
4182  result.operands)))
4183  return failure();
4184 
4185  // Propagate the types into the region arguments.
4186  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4187  regionArgs[i].type = functionType.getInput(i);
4188 
4189  return failure(parser.parseRegion(*cond, regionArgs) ||
4190  parser.parseKeyword("do") || parser.parseRegion(*body) ||
4191  parser.parseOptionalAttrDictWithKeyword(result.attributes));
4192 }
4193 
4194 void WhileOp::print(OpAsmPrinter &parser) {
4195  printInitializationList(parser, getCondGraph().front().getArguments(),
4196  getInputList(), " ");
4197  parser << " : ";
4198  parser.printFunctionalType(getInputList().getTypes(),
4199  getResults().getTypes());
4200  parser << ' ';
4201  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4202  parser << " do ";
4203  parser.printRegion(getBodyGraph());
4204  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4205 }
4206 
4207 // Create a rank-1 const tensor holding the zero point of the source tensor.
4208 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4209  Location loc,
4210  Type srcElemType,
4211  int64_t zp) {
4212  srcElemType = getStorageElementTypeOrSelf(srcElemType);
4213  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4214  if (llvm::isa<FloatType>(srcElemType)) {
4215  auto zpAttr = DenseElementsAttr::get(
4216  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4217  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4218  }
4219  if (llvm::isa<IntegerType>(srcElemType)) {
4220  auto zpAttr =
4221  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4222  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4223  }
4224  llvm::errs() << "zero point is not allowed for unsupported data types\n";
4225  return std::nullopt;
4226 }
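
A hedged usage sketch of the helper above: only createZeroPointTensor itself comes from this file; the surrounding function, includes and comments are assumptions for illustration.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include <optional>

// Materialize a rank-1 i8 zero-point constant of value 0, e.g. to feed the
// input_zp operand of a quantized op under construction.
static void buildExampleZeroPoint(mlir::OpBuilder &builder,
                                  mlir::Location loc) {
  std::optional<mlir::Value> zp = mlir::tosa::createZeroPointTensor(
      builder, loc, builder.getI8Type(), /*zp=*/0);
  if (!zp)
    return;    // unsupported element type, see the error path above
  (void)*zp;   // pass *zp to the op being built
}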
4227 
4228 //===----------------------------------------------------------------------===//
4229 // TOSA Shape and Shape Operators Helper functions.
4230 //===----------------------------------------------------------------------===//
4231 
4232 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4233  return mlir::isa<tosa::shapeType>(t);
4234 }
4235 
4236 LogicalResult
4237 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4238  int rank) {
4239  if (rank < 0)
4240  return emitError() << "invalid rank (must be >= 0): " << rank;
4241  return success();
4242 }
4243 
4244 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4245  for (auto v : op->getOperands()) {
4246  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4247  Operation *definingOp = v.getDefiningOp();
4248  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4249  return op->emitOpError("shape operand is not compile time resolvable");
4250  }
4251  }
4252  }
4253  return success();
4254 }
4255 
4256 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
4257  for (auto type : op->getOperandTypes()) {
4258  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4259  return op->emitOpError("must have operands with tosa shape type");
4260  }
4261  }
4262  for (auto type : op->getResultTypes()) {
4263  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4264  return op->emitOpError("must have result with tosa shape type");
4265  }
4266  }
4267  return success();
4268 }
4269 
4270 LogicalResult
4271 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4272  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4273  failed(verifyTosaShapeOperator(op)))
4274  return failure();
4275 
4276  // Helper lambda that returns the rank of a tosa shape type.
4277  auto getRank = [](const Type type) {
4278  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4279  };
4280  auto operandTypes = op->getOperandTypes();
4281  auto resultTypes = op->getResultTypes();
4282 
4283  auto rank = getRank(*op->getOperandTypes().begin());
4284  for (auto type : operandTypes) {
4285  if (getRank(type) != rank) {
4286  return op->emitOpError("operands don't have matching ranks");
4287  }
4288  }
4289  for (auto type : resultTypes) {
4290  if (getRank(type) != rank) {
4291  return op->emitOpError("result shape has different rank than operands");
4292  }
4293  }
4294  return success();
4295 }
4296 
4297 //===----------------------------------------------------------------------===//
4298 // TOSA Shape Operators verify functions.
4299 //===----------------------------------------------------------------------===//
4300 
4301 LogicalResult tosa::ConstShapeOp::verify() {
4302  // Check that the values attribute is one-dimensional (rank 1).
4303  auto valuesRank = getValues().getType().getRank();
4304  if (valuesRank != 1)
4305  return emitOpError("expect elements in attribute values with rank 1");
4306  // Check that the number of elements in the values attribute equals the rank of the result shape.
4307  auto count = getValues().getNumElements();
4308  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4309  if (count != rank && (count != 1 || rank != 0)) {
4310  return emitOpError("expect number of elements in attribute values (")
4311  << count << ") to be equal to the rank (" << rank
4312  << ") for the result shape type";
4313  }
4314  return success();
4315 }
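
Concretely, the count/rank rule above accepts, for example, three values for a rank-3 result shape and a single value for a rank-0 result; a minimal hedged sketch (invented names):

#include <cassert>
#include <cstdint>

// Number of elements in the values attribute vs. the rank of the result type.
static bool constShapeCountOk(int64_t count, int64_t rank) {
  return count == rank || (count == 1 && rank == 0);
}

int main() {
  assert(constShapeCountOk(3, 3));  // three values for a rank-3 shape
  assert(constShapeCountOk(1, 0));  // single value allowed for rank 0
  assert(!constShapeCountOk(2, 3)); // mismatch is rejected
  return 0;
}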
4316 
4317 //===----------------------------------------------------------------------===//
4318 // TOSA Attribute Definitions.
4319 //===----------------------------------------------------------------------===//
4320 
4321 #define GET_ATTRDEF_CLASSES
4322 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4323 
4324 //===----------------------------------------------------------------------===//
4325 // TOSA Type Definitions.
4326 //===----------------------------------------------------------------------===//
4327 #define GET_TYPEDEF_CLASSES
4328 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4329 
4330 //===----------------------------------------------------------------------===//
4331 // TOSA Operator Definitions.
4332 //===----------------------------------------------------------------------===//
4333 
4334 #define GET_OP_CLASSES
4335 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:51
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:1245
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition: TosaOps.cpp:2390
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:936
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2963
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition: TosaOps.cpp:532
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition: TosaOps.cpp:900
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2988
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:572
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition: TosaOps.cpp:909
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:3178
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:1359
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition: TosaOps.cpp:1373
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:3013
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:3081
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition: TosaOps.cpp:2460
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:1223
static LogicalResult verifyConvOpErrorIf(T op)
Definition: TosaOps.cpp:718
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:677
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:3069
static Type getStorageElementTypeOrSelf(Type type)
Definition: TosaOps.cpp:521
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:2979
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:1404
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition: TosaOps.cpp:1319
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:1199
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:1274
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition: TosaOps.cpp:858
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:136
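A standalone sketch of the dynamic-dimension convention this helper (and convertFromMlirShape below) bridges: TOSA-style shapes use -1 for an unknown extent, while MLIR ranked types use the ShapedType::kDynamic sentinel. The local constant stands in for that sentinel and is an assumption made so the demo is self-contained.
#include <cstdint>
#include <limits>
#include <vector>

static constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Map -1 entries of a TOSA-style shape to the dynamic sentinel.
static std::vector<int64_t> toMlirStyleShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> out;
  out.reserve(shape.size());
  for (int64_t dim : shape)
    out.push_back(dim == -1 ? kDynamicDim : dim);
  return out;
}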
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition: TosaOps.cpp:2418
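A standalone sketch of the range rule a zero-point verifier of this kind can enforce: the zero point must be representable in the operand's narrow integer storage type. The helper name and the exact rule are assumptions for illustration.
#include <cstdint>

// For a signed 8-bit operand the admissible range is [-128, 127]; for an
// unsigned 16-bit operand it is [0, 65535]. Intended for narrow widths only.
static bool zeroPointFits(int64_t zp, unsigned bitWidth, bool isSigned) {
  const int64_t span = int64_t(1) << bitWidth;
  const int64_t lo = isSigned ? -span / 2 : 0;
  const int64_t hi = isSigned ? span / 2 - 1 : span - 1;
  return zp >= lo && zp <= hi;
}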
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:3858
static LogicalResult verifyPoolingOp(T op)
Definition: TosaOps.cpp:999
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:515
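A minimal standalone sketch of an exact-division check of this kind: a value is produced only when the divisor is nonzero and divides the dividend evenly, letting shape arithmetic reject ill-formed static configurations. The name is illustrative.
#include <cstdint>
#include <optional>

static std::optional<int64_t> exactDivide(int64_t lhs, int64_t rhs) {
  if (rhs == 0 || lhs % rhs != 0)
    return std::nullopt; // reject division by zero and non-exact quotients
  return lhs / rhs;
}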
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition: TosaOps.cpp:1493
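A standalone sketch of the power-of-two test such a verifier relies on (illustrative only):
#include <cstdint>

static bool isPowerOfTwo(int64_t n) {
  return n > 0 && (n & (n - 1)) == 0; // exactly one bit set
}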
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:108
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:254
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:193
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma-separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:207
This class indicates that the op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of the index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:920
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:4256
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4271
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:4244
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
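A self-contained sketch of right-aligned shape broadcasting in the spirit of this helper, with -1 standing in for a dynamic extent; it is illustrative and simplifies the real dynamic-dimension handling.
#include <algorithm>
#include <cstdint>
#include <vector>

static bool broadcastShapes(const std::vector<int64_t> &a,
                            const std::vector<int64_t> &b,
                            std::vector<int64_t> &result) {
  const size_t rank = std::max(a.size(), b.size());
  result.assign(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    int64_t &out = result[rank - 1 - i];
    if (da == db || db == 1)
      out = da;     // equal extents, or b broadcasts along this dimension
    else if (da == 1)
      out = db;     // a broadcasts along this dimension
    else if (da == -1 || db == -1)
      out = -1;     // a dynamic extent keeps the result dynamic
    else
      return false; // incompatible static extents, e.g. 3 vs. 4
  }
  return true;
}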
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
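A hedged sketch of the accumulator-widening idea behind this helper: quantized integer convolutions accumulate into a type wider than their operands (for example, 8-bit inputs into a 32-bit result). The exact mapping below is an assumption for illustration, not a restatement of the implementation.
// Assumed widening rule for this sketch: 8-bit operands accumulate in 32 bits,
// 16-bit operands in 48 bits, and anything wider is left unchanged.
static unsigned convAccumulatorWidth(unsigned inputBitWidth) {
  if (inputBitWidth <= 8)
    return 32;
  if (inputBitWidth <= 16)
    return 48;
  return inputBitWidth;
}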
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:225
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:250
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:4208
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:4232
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:552
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and the pairwise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45