1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 #include <numeric>
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // Tosa dialect interface includes.
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47 
48 namespace {
49 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Dialect Function Inliner Interface.
53 //===----------------------------------------------------------------------===//
54 struct TosaInlinerInterface : public DialectInlinerInterface {
55  using DialectInlinerInterface::DialectInlinerInterface;
56 
57  //===--------------------------------------------------------------------===//
58  // Analysis Hooks.
59  //===--------------------------------------------------------------------===//
60 
61  /// All operations can be inlined by default.
62  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63  IRMapping &map) const final {
64  return true;
65  }
66 
67  /// All regions with If and While parent operators can be inlined.
68  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69  IRMapping &map) const final {
70  return (isa<tosa::IfOp>(dest->getParentOp()) ||
71  isa<tosa::WhileOp>(dest->getParentOp()));
72  }
73 };
74 
75 /// This class implements the bytecode interface for the Tosa dialect.
76 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77  TosaDialectBytecodeInterface(Dialect *dialect)
78  : BytecodeDialectInterface(dialect) {}
79 
80  //===--------------------------------------------------------------------===//
81  // Attributes
82 
83  Attribute readAttribute(DialectBytecodeReader &reader) const override {
84  return ::readAttribute(getContext(), reader);
85  }
86 
87  LogicalResult writeAttribute(Attribute attr,
88  DialectBytecodeWriter &writer) const override {
89  return ::writeAttribute(attr, writer);
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Types
94 
95  Type readType(DialectBytecodeReader &reader) const override {
96  return ::readType(getContext(), reader);
97  }
98 
99  LogicalResult writeType(Type type,
100  DialectBytecodeWriter &writer) const override {
101  return ::writeType(type, writer);
102  }
103 
104  void writeVersion(DialectBytecodeWriter &writer) const final {
105  // TODO: Populate.
106  }
107 
108  std::unique_ptr<DialectVersion>
109  readVersion(DialectBytecodeReader &reader) const final {
110  // TODO: Populate
111  reader.emitError("Dialect does not support versioning");
112  return nullptr;
113  }
114 
115  LogicalResult upgradeFromVersion(Operation *topLevelOp,
116  const DialectVersion &version) const final {
117  return success();
118  }
119 };
120 
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // TOSA control flow support.
125 //===----------------------------------------------------------------------===//
126 
127 /// Returns the while loop body.
128 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129  return {&getBodyGraph()};
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA variable operator support.
134 //===----------------------------------------------------------------------===//
135 
136 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137  return to_vector(llvm::map_range(shape, [](int64_t dim) {
138  return dim == -1 ? ShapedType::kDynamic : dim;
139  }));
140 }
141 
142 // Returns the tensor type of the given variable op.
143 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144  Type elementType = variableOp.getType();
145  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147  return RankedTensorType::get(shape, elementType);
148 }
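// Illustrative example (added, not from the upstream source): for a
// declaration such as
//   tosa.variable @v = dense<0.0> : tensor<2x4xf32>
// var_shape holds [2, 4] and the element type is f32, so getVariableType
// returns tensor<2x4xf32>; a -1 entry in var_shape maps to a dynamic
// dimension in the returned RankedTensorType.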
149 
150 //===----------------------------------------------------------------------===//
151 // Tosa dialect initialization.
152 //===----------------------------------------------------------------------===//
153 
154 void TosaDialect::initialize() {
155  addTypes<
156 #define GET_TYPEDEF_LIST
157 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158  >();
159  addOperations<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162  >();
163  addAttributes<
164 #define GET_ATTRDEF_LIST
165 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166  >();
167  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168  declarePromisedInterfaces<
169  shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175  GreaterEqualOp, MatMulOp>();
176 }
177 
178 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179  Type type, Location loc) {
180  // Tosa dialect constants only support ElementsAttr unlike standard dialect
181  // constant which supports all attributes.
182  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183  return tosa::ConstShapeOp::create(builder, loc, type,
184  llvm::cast<DenseIntElementsAttr>(value));
185  }
186  if (llvm::isa<ElementsAttr>(value))
187  return tosa::ConstOp::create(builder, loc, type,
188  llvm::cast<ElementsAttr>(value));
189  return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Parsers and printers
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199  DenseElementsAttr &varShapeAttr,
200  TypeAttr &typeAttr) {
201  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202  if (!shapedType.hasRank())
203  return parser.emitError(parser.getCurrentLocation())
204  << "expected ranked type";
205 
206  auto elementType = shapedType.getElementType();
207  typeAttr = TypeAttr::get(elementType);
208  ArrayRef<int64_t> shape = shapedType.getShape();
209  Builder builder(parser.getContext());
210  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211  return success();
212  }
213  return parser.emitError(parser.getCurrentLocation())
214  << "expected shaped type";
215 }
216 
217 } // namespace
218 
219 // parses the optional initial value or type for a tosa variable
220 // with initial value:
221 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222 //
223 // without initial value:
224 // tosa.variable @name : tensor<1x8xf32>
225 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227  Attribute &initialValueAttr) {
228  if (succeeded(parser.parseOptionalEqual())) {
229  if (failed(parser.parseAttribute(initialValueAttr))) {
230  return parser.emitError(parser.getCurrentLocation())
231  << "expected attribute";
232  }
233  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235  typeAttr);
236  }
237  return parser.emitError(parser.getCurrentLocation())
238  << "expected Typed attr";
239  }
240 
241  initialValueAttr = nullptr;
242  Type parsedType;
243  if (failed(parser.parseColonType(parsedType))) {
244  return parser.emitError(parser.getCurrentLocation())
245  << "expected type after colon";
246  }
247  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248 }
249 
250 void mlir::tosa::printVariableOpTypeOrInitialValue(
251  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252  TypeAttr typeAttr, Attribute initialValueAttr) {
253  bool needsSpace = false;
254  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255  auto shape =
256  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257  Type elementType = typeAttr.getValue();
258  RankedTensorType tensorType =
259  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260  auto tensorTypeAttr = TypeAttr::get(tensorType);
261  p << ": ";
262  p.printAttribute(tensorTypeAttr);
263  needsSpace = true; // subsequent attr value needs a space separator
264  }
265  if (initialValueAttr) {
266  if (needsSpace)
267  p << ' ';
268  p << "= ";
269  p.printAttribute(initialValueAttr);
270  }
271 }
272 
273 namespace {
274 
275 // parse attributes with special handling for tosa enum attributes
276 template <typename EnumType>
277 ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278  NamedAttrList &outAttrs) {
279  llvm::StringRef name;
280  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281  return failure();
282 
283  // special handling: rounding_mode accepts a *bare* RoundingMode enum
284  // keyword.
285  llvm::StringRef kw;
286  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287  if (name == "rounding_mode" &&
288  succeeded(parser.parseOptionalKeyword(&kw))) {
289  auto sym = symbolizeRoundingMode(kw);
290  if (!sym)
291  return parser.emitError(parser.getCurrentLocation())
292  << "invalid rounding_mode value: " << kw;
293  auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294  outAttrs.push_back(NamedAttribute(name, attr));
295  return success();
296  }
297  }
298  // special handling: mode accepts a *bare* ResizeMode enum keyword.
299  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300  if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301  auto sym = symbolizeResizeMode(kw);
302  if (!sym)
303  return parser.emitError(parser.getCurrentLocation())
304  << "invalid resize mode value: " << kw;
305  auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306  outAttrs.push_back(NamedAttribute(name, attr));
307  return success();
308  }
309  }
310  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311  // keyword.
312  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313  if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314  auto sym = symbolizeNanPropagationMode(kw);
315  if (!sym)
316  return parser.emitError(parser.getCurrentLocation())
317  << "invalid nan_mode value: " << kw;
318  auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319  outAttrs.push_back(NamedAttribute(name, attr));
320  return success();
321  }
322  }
323 
324  // Default path: parse any normal attribute literal, including fully qualified
325  // enum keyword
326  Attribute attr;
327  return parser.parseAttribute(attr, name, outAttrs);
328 }
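// Illustrative sketch (added; the attribute syntax shown is an assumption,
// not taken from this file): with this handling, an attribute entry such as
//   rounding_mode = SINGLE_ROUND
// parses the bare enum keyword, while a fully qualified form such as
//   rounding_mode = #tosa<rounding_mode SINGLE_ROUND>
// still goes through the default parseAttribute() path.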
329 
330 template <typename EnumType>
331 ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
332  // Parse the operand list.
333  SmallVector<OpAsmParser::UnresolvedOperand> operands;
334  if (parser.parseCommaSeparatedList(
335  [&]() { return parser.parseOperand(operands.emplace_back()); }))
336  return failure();
337 
338  // Parse { attr-dict } with special handling for enum bare token
339  NamedAttrList attrs;
340  if (succeeded(parser.parseOptionalLBrace()) &&
341  failed(parser.parseOptionalRBrace())) {
342  do {
343  if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
344  return failure();
345  } while (succeeded(parser.parseOptionalComma()));
346  if (parser.parseRBrace())
347  return failure();
348  }
349 
350  FunctionType fnTy;
351  if (parser.parseColonType(fnTy))
352  return failure();
353 
354  // Resolve operands and types
355  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
356  parser.getCurrentLocation(),
357  result.operands)))
358  return failure();
359 
360  result.addTypes(fnTy.getResult(0));
361  result.addAttributes(attrs);
362 
363  return success();
364 }
365 
366 void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
367  parser << namedAttr.getName().strref() << " = ";
368  auto attr = namedAttr.getValue();
369  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
370  parser << roundingModeAttr.getValue();
371  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
372  parser << resizeModeAttr.getValue();
373  } else if (auto nanPropagationModeAttr =
374  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
375  parser << nanPropagationModeAttr.getValue();
376  } else {
377  parser.printAttribute(attr);
378  }
379 }
380 
381 // print with special handling for default valued NanPropagationMode attribute
382 void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
383  parser << " ";
384  parser.printOperands(op->getOperands());
385 
386  NamedAttrList toPrint(op->getAttrs());
387  // remove default NanPropagate attribute
388  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
389  for (auto attr : op->getAttrs()) {
390  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
391  if (nanAttr.getValue() == kDefaultNanValue) {
392  // elide from toPrint
393  toPrint.erase(attr.getName());
394  break;
395  }
396  }
397  }
398 
399  if (!toPrint.empty()) {
400  parser << " {";
401  llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
402  printNamedAttr(parser, namedAttr);
403  });
404  parser << "}";
405  }
406 
407  parser << " : ";
408  parser.printFunctionalType(op);
409 }
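// Illustrative note (added for clarity): an op carrying the default
// nan_mode = PROPAGATE prints with that attribute elided, e.g. the assumed
// form
//   tosa.maximum %a, %b : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// round-trips without an explicit nan_mode entry.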
410 
411 // print with special handling for enums: RoundingMode, ResizeMode
412 void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
413  parser << " ";
414  parser.printOperands(op->getOperands());
415 
416  if (!op->getAttrs().empty()) {
417  parser << " {";
418  llvm::interleaveComma(op->getAttrs(), parser,
419  [&](const NamedAttribute namedAttr) {
420  printNamedAttr(parser, namedAttr);
421  });
422  parser << "}";
423  }
424 
425  parser << " : ";
426  parser.printFunctionalType(op);
427 }
428 
429 } // namespace
430 
431 ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
432  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
433 }
434 
435 void RescaleOp::print(OpAsmPrinter &parser) {
436  printWithEnumHandling(parser, *this);
437 }
438 
439 ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
440  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
441 }
442 
443 void ApplyScaleOp::print(OpAsmPrinter &parser) {
444  printWithEnumHandling(parser, *this);
445 }
446 
447 ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
448  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
449 }
450 
451 void ResizeOp::print(OpAsmPrinter &parser) {
452  printWithEnumHandling(parser, *this);
453 }
454 
455 ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
456  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
457 }
458 
459 void ArgMaxOp::print(OpAsmPrinter &parser) {
460  printWithNanPropagationHandling(parser, *this);
461 }
462 
463 ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
464  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
465 }
466 
467 void MaxPool2dOp::print(OpAsmPrinter &parser) {
468  printWithNanPropagationHandling(parser, *this);
469 }
470 
471 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
472  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
473 }
474 
475 void ClampOp::print(OpAsmPrinter &parser) {
476  printWithNanPropagationHandling(parser, *this);
477 }
478 
479 ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
480  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
481 }
482 
483 void MaximumOp::print(OpAsmPrinter &parser) {
484  printWithNanPropagationHandling(parser, *this);
485 }
486 
487 ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
488  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
489 }
490 
491 void MinimumOp::print(OpAsmPrinter &parser) {
492  printWithNanPropagationHandling(parser, *this);
493 }
494 
495 ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
496  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
497 }
498 
499 void ReduceMaxOp::print(OpAsmPrinter &parser) {
500  printWithNanPropagationHandling(parser, *this);
501 }
502 
503 ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
504  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
505 }
506 
507 void ReduceMinOp::print(OpAsmPrinter &parser) {
508  printWithNanPropagationHandling(parser, *this);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Tosa utilities.
513 //===----------------------------------------------------------------------===//
514 
515 static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
516  if (lhs % rhs != 0)
517  return std::nullopt;
518  return lhs / rhs;
519 }
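// Worked example (added for clarity): idivCheck(12, 4) yields 3, while
// idivCheck(13, 4) yields std::nullopt because 13 is not evenly divisible
// by 4.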
520 
521 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
522  auto srcType = getElementTypeOrSelf(type);
523  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
524  srcType = quantType.getStorageType();
525  return srcType;
526 }
527 
527 
528 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
529  return getStorageElementTypeOrSelf(value.getType());
530 }
531 
532 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
533  Value valZp, StringRef name) {
534  Type eType = getStorageElementTypeOrSelf(val.getType());
535  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
536 
537  bool bothInts =
538  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
539  bool sameBitWidth =
540  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
541 
542  if (!bothInts || !sameBitWidth) {
543  return op->emitOpError()
544  << "expected " << name << " and " << name
545  << "_zp to both be integer of the same bitwidth, but got " << eType
546  << " vs. " << eZpType;
547  }
548  return success();
549 }
550 
551 // Create a pad-const tensor with the value `val` of the required data type.
552 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
553  Value src, int32_t val) {
554  const auto srcType = getElementTypeOrSelf(src);
555  const auto srcElemType = getStorageElementTypeOrSelf(src);
556  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
557  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
558  const auto padConstAttr{
559  llvm::isa<FloatType>(srcElemType)
560  ? DenseElementsAttr::get(padConstEType,
561  builder.getFloatAttr(srcElemType, val))
562  : DenseElementsAttr::get(padConstEType,
563  builder.getIntegerAttr(srcElemType, val))};
564  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
565 }
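// Usage sketch (added for clarity): createPadConstTensor(builder, loc, input, 0)
// materializes a tensor<1xT> tosa.const holding zero in the storage element
// type of `input`, which the pad builder below uses as the pad_const operand.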
566 
567 //===----------------------------------------------------------------------===//
568 // TOSA Operator Verifiers.
569 //===----------------------------------------------------------------------===//
570 
571 template <typename T>
572 static LogicalResult verifyConvOp(T op) {
573  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
574  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
575 
576  auto inputEType = inputType.getElementType();
577  auto weightEType = weightType.getElementType();
578  auto biasEType =
579  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
580  auto resultEType =
581  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
582  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
583  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
584 
585  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
586  inputEType = quantType.getStorageType();
587 
588  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
589  weightEType = quantType.getStorageType();
590 
591  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
592  biasEType = quantType.getStorageType();
593 
594  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
595  resultEType = quantType.getStorageType();
596 
597  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
598  // for now, only enforce bias element type == result element type for
599  // float types.
600  op.emitOpError(
601  "expect both bias and result to have same element type, got ")
602  << biasEType << " and " << resultEType;
603  return failure();
604  }
605 
606  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
607  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
608  if (inputEType != weightEType) {
609  op.emitOpError(
610  "expect both input and weight to have same element type, got ")
611  << inputEType << " and " << weightEType;
612  return failure();
613  }
614  }
615 
616  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
617  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
618 
619  // Either both must be float or both non-float.
620  if (inputIsFloat != weightIsFloat) {
621  op.emitOpError(
622  "expect both input and weight to be float or not together, got ")
623  << inputEType << " and " << weightEType;
624  return failure();
625  }
626 
627  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
628  if (inputEType != inputZpEType) {
629  return op.emitOpError("expect both input and its zero point are the same "
630  "element type, got ")
631  << inputEType << " and " << inputZpEType;
632  }
633 
634  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
635  if (weightEType != weightZpEType) {
636  return op.emitOpError("expect both weight and its zero point are the same "
637  "element type, got ")
638  << weightEType << " and " << weightZpEType;
639  }
640 
641  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
642  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
643  return failure();
644 
645  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
646  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
647  return failure();
648 
649  return success();
650 }
651 
652 LogicalResult tosa::ConstOp::verify() {
653 
654  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
655  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
656 
657  if (!attrType || !outputType) {
658  emitOpError("expected tensors for attr/result type");
659  return failure();
660  }
661 
662  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
663  outputType.getElementType())) {
664  if (result.getStorageType() == attrType.getElementType())
665  return success();
666  }
667 
668  if (attrType.getElementType() != outputType.getElementType()) {
669  emitOpError("expected same attr/result element types");
670  return failure();
671  }
672 
673  return success();
674 }
675 
676 template <typename T>
677 static LogicalResult verifyConvOpModes(T op) {
678  auto inputEType =
679  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
680 
681  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
682  inputEType = quantType.getStorageType();
683 
684  auto accType = op.getAccType();
685  if (inputEType.isInteger(8) && !accType.isInteger(32))
686  return op.emitOpError("accumulator type for i8 tensor is not i32");
687 
688  if (inputEType.isInteger(16) && !accType.isInteger(48))
689  return op.emitOpError("accumulator type for i16 tensor is not i48");
690 
691  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
692  return op.emitOpError("accumulator type for f8 tensor is not f16");
693 
694  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
695  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
696 
697  if (inputEType.isBF16() && !accType.isF32())
698  return op.emitOpError("accumulator type for bf16 tensor is not f32");
699 
700  if (inputEType.isF32() && !accType.isF32())
701  return op.emitOpError("accumulator type for f32 tensor is not f32");
702 
703  auto resultEType =
704  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
705 
706  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
707  resultEType = quantType.getStorageType();
708 
709  return success();
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // ERROR_IF functions.
714 // ERROR_IF is a predicate that must set an error if the condition holds.
715 //===----------------------------------------------------------------------===//
716 
717 template <typename T>
718 static LogicalResult verifyConvOpErrorIf(T op) {
719  llvm::ArrayRef<int64_t> padding = op.getPad();
720  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
721  return op.emitOpError("expect all padding values to be >= 0, got ")
722  << padding;
723 
724  llvm::ArrayRef<int64_t> strides = op.getStride();
725  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
726  return op.emitOpError("expect all stride values to be >= 1, got ")
727  << strides;
728 
729  llvm::ArrayRef<int64_t> dilations = op.getDilation();
730  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
731  return op.emitOpError("expect all dilation values to be >= 1, got ")
732  << dilations;
733 
734  const RankedTensorType outputType =
735  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
736  if (!outputType)
737  // Skip following checks if output is not ranked
738  return success();
739 
740  const RankedTensorType inputType =
741  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
742  const RankedTensorType weightType =
743  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
744 
745  if (inputType && weightType) {
746  const auto verifyOutputSize =
747  [&op](const int64_t inputSize, const int64_t kernelSize,
748  const int64_t outputSize, const int64_t padBefore,
749  const int64_t padAfter, const int64_t stride,
750  const int64_t dilation, const llvm::StringRef dimName,
751  const llvm::StringRef dimAxis,
752  const llvm::StringRef padBeforeName,
753  const llvm::StringRef padAfterName) -> LogicalResult {
754  if (inputSize == ShapedType::kDynamic ||
755  kernelSize == ShapedType::kDynamic)
756  return success();
757 
758  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
759 
760  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
761  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
762  stride);
763  if (!calculatedOutSizeMinusOne.has_value())
764  return op.emitOpError("expected input_")
765  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
766  << padAfterName << " - (kernel_" << dimName
767  << " - 1) * dilation_" << dimAxis
768  << " to be wholly divisible by stride_" << dimAxis << ", got ("
769  << inputSize << " - 1 + " << padBefore << " + " << padAfter
770  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
771  << stride;
772 
773  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
774  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
775  return op.emitOpError("calculated output ")
776  << dimName << " did not match expected: "
777  << "calculated=" << calculatedOutSize
778  << ", expected=" << outputSize;
779 
780  return success();
781  };
782 
783  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
784  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
785  if (failed(verifyOutputSize(
786  inputType.getDimSize(1), weightType.getDimSize(1),
787  outputType.getDimSize(1), padding[0], padding[1], strides[0],
788  dilations[0], "height", "y", "top", "bottom")))
789  return failure();
790 
791  if (failed(verifyOutputSize(
792  inputType.getDimSize(2), weightType.getDimSize(2),
793  outputType.getDimSize(2), padding[2], padding[3], strides[1],
794  dilations[1], "width", "x", "left", "right")))
795  return failure();
796  }
797 
798  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
799  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
800  if (failed(verifyOutputSize(
801  inputType.getDimSize(1), weightType.getDimSize(0),
802  outputType.getDimSize(1), padding[0], padding[1], strides[0],
803  dilations[0], "height", "y", "top", "bottom")))
804  return failure();
805 
806  if (failed(verifyOutputSize(
807  inputType.getDimSize(2), weightType.getDimSize(1),
808  outputType.getDimSize(2), padding[2], padding[3], strides[1],
809  dilations[1], "width", "x", "left", "right")))
810  return failure();
811  }
812 
813  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
814  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
815  if (failed(verifyOutputSize(
816  inputType.getDimSize(1), weightType.getDimSize(1),
817  outputType.getDimSize(1), padding[0], padding[1], strides[0],
818  dilations[0], "depth", "d", "front", "back")))
819  return failure();
820 
821  if (failed(verifyOutputSize(
822  inputType.getDimSize(2), weightType.getDimSize(2),
823  outputType.getDimSize(2), padding[2], padding[3], strides[1],
824  dilations[1], "height", "y", "top", "bottom")))
825  return failure();
826 
827  if (failed(verifyOutputSize(
828  inputType.getDimSize(3), weightType.getDimSize(3),
829  outputType.getDimSize(3), padding[4], padding[5], strides[2],
830  dilations[2], "width", "x", "left", "right")))
831  return failure();
832  }
833  }
834 
835  const RankedTensorType biasType =
836  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
837  if (!biasType)
838  // Skip following checks if bias is not ranked
839  return success();
840 
841  const int64_t biasChannels = biasType.getDimSize(0);
842  const int64_t outputChannels =
843  outputType.getDimSize(outputType.getRank() - 1);
844  if (biasChannels == ShapedType::kDynamic ||
845  outputChannels == ShapedType::kDynamic)
846  // Skip following checks if biasChannels or outputChannels is dynamic dim
847  return success();
848 
849  if (biasChannels != outputChannels && biasChannels != 1)
850  return op.emitOpError(
851  "bias channels expected to be equal to output channels (")
852  << outputChannels << ") or 1, got " << biasChannels;
853 
854  return success();
855 }
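// Worked example (added for clarity): with input height I = 17, kernel
// height K = 3, pad_top = pad_bottom = 1, stride = 2 and dilation = 1, the
// check computes (17 - 1 + 1 + 1 - (3 - 1) * 1) / 2 + 1 = 16 / 2 + 1 = 9,
// so a static output height other than 9 (or a non-divisible numerator) is
// rejected.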
856 
857 // Verify that the given two types have the same element type and shape.
858 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
859  StringRef name1, Type type2,
860  StringRef name2) {
861  auto shapeType1 = dyn_cast<ShapedType>(type1);
862  auto shapeType2 = dyn_cast<ShapedType>(type2);
863  if (!shapeType1 || !shapeType2)
864  return failure();
865 
866  auto elemType1 = shapeType1.getElementType();
867  auto elemType2 = shapeType2.getElementType();
868  if (elemType1 != elemType2)
869  return op->emitOpError()
870  << "require same element type for " << name1 << " (" << elemType1
871  << ") and " << name2 << " (" << elemType2 << ")";
872 
873  if (failed(verifyCompatibleShape(type1, type2)))
874  return op->emitOpError()
875  << "require same shapes for " << name1 << " (" << type1 << ") and "
876  << name2 << " (" << type2 << ")";
877 
878  return success();
879 }
880 
881 // Verify that the given two tensor lists have the same length, element types, and shapes.
882 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
883  StringRef name1,
884  ValueRange list2,
885  StringRef name2) {
886  if (list1.size() != list2.size())
887  return op->emitOpError()
888  << "require same number of values in " << name1 << " ("
889  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
890 
891  for (auto [type1, type2] :
892  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
893  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
894  return failure();
895  }
896 
897  return success();
898 }
899 
900 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
901  ShapeAdaptor shapeAdaptor(type);
902  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
903  return success();
904 
905  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
906 }
907 
908 // Returns the first declaration point prior to this operation or failure if
909 // not found.
910 static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
911  StringRef symName) {
912  ModuleOp module = op->getParentOfType<ModuleOp>();
913  tosa::VariableOp varOp = nullptr;
914 
915  // TODO: Adopt the SymbolTable trait for Variable ops.
916  // Currently, the variable's definition point is searched via walk(),
917  // starting from the top-level ModuleOp and stopping at the point of use. Once
918  // the TOSA control flow and variable extensions reach a complete state, this
919  // may leverage MLIR's SymbolTable functionality to look up the symbol and
920  // enhance the search to a TOSA-specific graph traversal over the IR structure.
921  module.walk([&](Operation *tempOp) {
922  // Reach this op itself.
923  if (tempOp == op) {
924  return WalkResult::interrupt();
925  }
926 
927  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
928  if (symName == tosaOp.getName()) {
929  varOp = tosaOp;
930  return WalkResult::interrupt();
931  }
932  }
933 
934  return WalkResult::advance();
935  });
936 
937  if (varOp)
938  return varOp;
939 
940  return failure();
941 }
942 
943 template <typename T>
944 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
945  StringRef symName = op.getName();
946  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
947  if (failed(varOp))
948  return op->emitOpError("'")
949  << symName << "' has not been declared by 'tosa.variable'";
950 
951  // Verify type and shape
952  auto variableType = getVariableType(varOp.value());
953  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
954  "the input tensor")
955  .failed())
956  return failure();
957 
958  return success();
959 }
960 
961 // Verify that inType and outType have the same element types.
962 template <typename T>
963 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
964  auto inputType = llvm::dyn_cast<TensorType>(inType);
965  auto outputType = llvm::dyn_cast<TensorType>(outType);
966  if (!inputType) {
967  op.emitOpError("expect shaped tensor for input, got ") << inType;
968  return failure();
969  }
970  if (!outputType) {
971  op.emitOpError("expect shaped tensor for output, got ") << outType;
972  return failure();
973  }
974  auto inputElementType = inputType.getElementType();
975  auto outputElementType = outputType.getElementType();
976  auto inputQuantType =
977  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
978  auto outputQuantType =
979  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
980  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
981  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
982  inputElementType != outputElementType) {
983  // Only check when both element types are int/index/float/UniformQuantized;
984  // e.g., it is not clear how to check quant::QuantizedType.
985  // This case occurs in test_conv2d_q_grouped_convolution in
986  // tfl-to-tosa-pipeline.mlir.
987  op.emitOpError("expect input and output to have same element type, got ")
988  << inputElementType << " and " << outputElementType;
989  return failure();
990  }
991  return success();
992 }
993 
994 LogicalResult tosa::ArgMaxOp::verify() {
995  const ShapedType resultType = llvm::cast<ShapedType>(getType());
996 
997  // Ensure output is of 32-bit integer
998  if (const auto resultETy = resultType.getElementType();
999  !resultETy.isIntOrIndex())
1000  return emitOpError("result tensor is not of integer type");
1001 
1002  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1003  if (!inputType.hasRank())
1004  return success();
1005 
1006  // Ensure axis is within the tensor rank
1007  const int64_t axis = getAxisAttr().getInt();
1008  if (((axis < 0) || axis >= inputType.getRank()))
1009  return emitOpError("specified axis is outside the rank of the tensor");
1010 
1011  if (!resultType.hasRank())
1012  return success();
1013 
1014  const ArrayRef<int64_t> inputShape = inputType.getShape();
1015  const ArrayRef<int64_t> outputShape = resultType.getShape();
1016  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1017  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1018  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1019  return emitOpError("expected output shape '")
1020  << expectedOutputShape << "', got '" << outputShape << "'";
1021 
1022  return success();
1023 }
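// Illustrative example (added, not from the source): for an input of type
// tensor<2x3x4xf32> with axis = 1, the expected output shape is [2, 4];
// the axis dimension is simply removed.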
1024 
1025 template <typename T>
1026 static LogicalResult verifyPoolingOp(T op) {
1027  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1028  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1029  return op.emitOpError("expect all kernel values to be >= 1, got ")
1030  << kernel;
1031 
1032  const llvm::ArrayRef<int64_t> strides = op.getStride();
1033  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1034  return op.emitOpError("expect all stride values to be >= 1, got ")
1035  << strides;
1036 
1037  const llvm::ArrayRef<int64_t> padding = op.getPad();
1038  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1039  return op.emitOpError("expect all padding values to be >= 0, got ")
1040  << padding;
1041 
1042  // Padding must be less than kernel size to avoid a divide-by-zero
1043  const int64_t kernelX = kernel[1];
1044  const int64_t padLeft = padding[2];
1045  const int64_t padRight = padding[3];
1046  if (padRight >= kernelX || padLeft >= kernelX)
1047  return op.emitOpError("expected left/right padding to be less than the "
1048  "width of the kernel, got pad_left=")
1049  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1050 
1051  const int64_t kernelY = kernel[0];
1052  const int64_t padTop = padding[0];
1053  const int64_t padBottom = padding[1];
1054  if (padTop >= kernelY || padBottom >= kernelY)
1055  return op.emitOpError("expected top/bottom padding to be less than the "
1056  "height of the kernel, got pad_top=")
1057  << padTop << ", pad_bottom=" << padBottom
1058  << ", kernel_y=" << kernelY;
1059 
1060  const auto inputType =
1061  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1062  const auto outputType =
1063  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1064  if (!inputType || !outputType)
1065  return success();
1066 
1067  const auto verifyOutputSize =
1068  [&op](const int64_t inputSize, const int64_t outputSize,
1069  const int64_t kernelSize, const int64_t strideSize,
1070  const int64_t padBefore, const int64_t padAfter,
1071  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1072  const llvm::StringRef padBeforeName,
1073  const llvm::StringRef padAfterName) -> LogicalResult {
1074  if (ShapedType::isDynamic(inputSize))
1075  return success();
1076 
1077  const std::optional<int64_t> calculatedOutSizeMinusOne =
1078  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1079  if (!calculatedOutSizeMinusOne.has_value())
1080  return op.emitOpError("expected input_")
1081  << dimName << " + pad_" << padBeforeName << " + pad_"
1082  << padAfterName << " - kernel_" << dimAxis
1083  << " to be wholly divisible by stride_" << dimAxis << ", got ("
1084  << inputSize << " + " << padBefore << " + " << padAfter << " - "
1085  << kernelSize << ") / " << strideSize;
1086 
1087  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1088  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1089  return op.emitOpError("calculated output ")
1090  << dimName << " did not match expected: "
1091  << "calculated=" << calculatedOutSize
1092  << ", expected=" << outputSize;
1093 
1094  return success();
1095  };
1096 
1097  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1098  kernel[0], strides[0], padding[0], padding[1],
1099  "height", "y", "top", "bottom")))
1100  return failure();
1101 
1102  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1103  kernel[1], strides[1], padding[2], padding[3],
1104  "width", "x", "left", "right")))
1105  return failure();
1106 
1107  return success();
1108 }
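// Worked example (added for clarity): with input height I = 32, kernel
// height K = 2, pad_top = pad_bottom = 0 and stride = 2, the check computes
// (32 + 0 + 0 - 2) / 2 + 1 = 16, so a static output height must equal 16.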
1109 
1110 LogicalResult tosa::AvgPool2dOp::verify() {
1111  if (failed(verifyPoolingOp(*this)))
1112  return failure();
1113 
1114  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1115  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1116  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1117  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1118 
1119  auto accType = getAccType();
1120  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1121  return emitOpError("accumulator type for integer tensor is not i32");
1122 
1123  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1124  return emitOpError("accumulator type for f16 tensor is not f16/f32");
1125 
1126  if (inputETy.isBF16() && !accType.isF32())
1127  return emitOpError("accumulator type for bf16 tensor is not f32");
1128 
1129  if (inputETy.isF32() && !accType.isF32())
1130  return emitOpError("accumulator type for f32 tensor is not f32");
1131 
1132  if (inputETy != inputZpETy)
1133  return emitOpError("expect both input and its zero point are the same "
1134  "element type, got ")
1135  << inputETy << " and " << inputZpETy;
1136 
1137  if (resultETy != outputZpETy)
1138  return emitOpError("expect both output and its zero point are the same "
1139  "element type, got ")
1140  << resultETy << " and " << outputZpETy;
1141 
1142  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1143  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1144  return failure();
1145 
1146  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1147  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1148  return failure();
1149 
1150  return success();
1151 }
1152 
1153 LogicalResult tosa::ClampOp::verify() {
1154  mlir::Type inputETy =
1155  llvm::cast<ShapedType>(getInput().getType()).getElementType();
1156  if (auto quantType =
1157  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1158  inputETy = quantType.getStorageType();
1159  }
1160  mlir::Type outputETy =
1161  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1162  if (auto quantType =
1163  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1164  outputETy = quantType.getStorageType();
1165  }
1166  if (inputETy != outputETy)
1167  return emitOpError("input/output element types are incompatible.");
1168 
1169  auto maxValAttr = getMaxValAttr();
1170  auto minValAttr = getMinValAttr();
1171 
1172  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1173 
1174  if (inputETy.isInteger(dataTypeBitWidth)) {
1175  // if input datatype is integer, check that the min_val/max_val attributes
1176  // are integer attributes, and that their type is the same as the input's
1177  // datatype
1178  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1179  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1180  if (!intMaxValAttr || !intMinValAttr ||
1181  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1182  (intMaxValAttr.getType() != inputETy))
1183  return emitOpError("min/max attributes types are incompatible with "
1184  "input/output element types.");
1185 
1186  const bool isUnsigned = inputETy.isUnsignedInteger();
1187  const bool isBoolean = inputETy.isInteger(1);
1188  const APInt minVal = intMinValAttr.getValue();
1189  const APInt maxVal = intMaxValAttr.getValue();
1190  if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1191  return emitOpError("expected min_val <= max_val, got min_val=")
1192  << minValAttr << ", max_val=" << maxValAttr;
1193  } else {
1194  // otherwise, input datatype is float, check that the min_val/max_val
1195  // attributes share the same type and that their type is the same as the
1196  // input's datatype
1197  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1198  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1199  if (!floatMaxValAttr || !floatMinValAttr ||
1200  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1201  (floatMaxValAttr.getType() != inputETy))
1202  return emitOpError("min/max attributes types are incompatible with "
1203  "input/output element types.");
1204 
1205  const APFloat minVal = floatMinValAttr.getValue();
1206  const APFloat maxVal = floatMaxValAttr.getValue();
1207  if (minVal.isNaN() || maxVal.isNaN())
1208  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1209  << minValAttr << ", max_val=" << maxValAttr;
1210 
1211  if (maxVal < minVal)
1212  return emitOpError("expected min_val <= max_val, got min_val=")
1213  << minValAttr << ", max_val=" << maxValAttr;
1214  }
1215 
1216  return success();
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // TOSA Operator Quantization Builders.
1221 //===----------------------------------------------------------------------===//
1222 
1223 /// This builder is called on all convolution operators except TransposeConv,
1224 /// which has specialized output shape semantics. The builder also defines the
1225 /// bitwidth of the output given the bit width of the input & weight content.
1226 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1227  Type outputType, Value input, Value weight,
1228  Value bias, DenseI64ArrayAttr pad,
1229  DenseI64ArrayAttr stride,
1230  DenseI64ArrayAttr dilation,
1231  TypeAttr accType) {
1232  auto zps = createZPsAsConst(builder, input, weight);
1233  result.addOperands({input, weight, bias, zps.first, zps.second});
1234  result.addAttribute("pad", pad);
1235  result.addAttribute("stride", stride);
1236  result.addAttribute("dilation", dilation);
1237  result.addAttribute("acc_type", accType);
1238  Type finalOutputType = outputType;
1239  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1240  if (quantAttr) {
1241  finalOutputType =
1242  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1243  }
1244  result.addTypes(finalOutputType);
1245 }
1246 
1247 /// Handles tosa.transpose_conv2d which has outpad and output shape
1248 /// attributes.
1249 static void
1250 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1251  Type outputType, Value input, Value weight,
1252  Value bias, DenseI64ArrayAttr outpad,
1253  DenseI64ArrayAttr stride, TypeAttr accType) {
1254  auto zps = createZPsAsConst(builder, input, weight);
1255  result.addOperands({input, weight, bias, zps.first, zps.second});
1256  result.addAttribute("out_pad", outpad);
1257  result.addAttribute("stride", stride);
1258  result.addAttribute("acc_type", accType);
1259  Type finalOutputType = outputType;
1260  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1261  if (quantAttr) {
1262  finalOutputType =
1263  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1264  }
1265  result.addTypes(finalOutputType);
1266 }
1267 
1268 /// The tosa.matmul op is also intended to be generated where a fully_connected
1269 /// op must be constructed but the weight is not a constant. In this case,
1270 /// the fully_connected op must be expressed using matmul.
1271 /// TODO: Add a link to the legalization document explaining this.
1272 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1273  OperationState &result, Type outputType,
1274  Value a, Value b) {
1275  auto zps = createZPsAsConst(builder, a, b);
1276  result.addOperands({a, b, zps.first, zps.second});
1277 
1278  Type finalOutputType{outputType};
1279  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1280  auto eType = getStorageElementTypeOrSelf(a.getType());
1281  auto inputBits = eType.getIntOrFloatBitWidth();
1282 
1283  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1284  assert(outputShapedType && "Output must be a shaped type");
1285 
1286  IntegerType accElementType;
1287  if (inputBits == 16)
1288  accElementType = builder.getIntegerType(48);
1289  else
1290  accElementType = builder.getI32Type();
1291 
1292  finalOutputType = outputShapedType.clone(accElementType);
1293  }
1294  result.addTypes(finalOutputType);
1295 }
1296 
1297 /// Both the tosa.avg_pool2d and unary ops use the same
1298 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
1299 /// it has additional parameters not part of the unary ops.
1300 static void
1301 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1302  Type outputType, Value input,
1303  DenseArrayAttr kernel, DenseArrayAttr stride,
1304  DenseArrayAttr pad, TypeAttr accType) {
1305  const Location loc{result.location};
1306  int64_t inputZp{0};
1307  int64_t outputZp{0};
1308 
1309  if (auto quantAttr =
1310  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1311  inputZp = quantAttr.getInputZp();
1312  outputZp = quantAttr.getOutputZp();
1313  }
1314  const std::optional<Value> inputZpOp =
1315  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1316  if (!inputZpOp) {
1317  (void)emitError(
1318  loc,
1319  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1320  }
1321  const std::optional<Value> outputZpOp =
1322  createZeroPointTensor(builder, loc, outputType, outputZp);
1323  if (!outputZpOp) {
1324  (void)emitError(loc, "Failed to create output zero point tensor for "
1325  "quantized AVG_POOL2D op");
1326  }
1327 
1328  if (inputZpOp && outputZpOp) {
1329  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1330  } else {
1331  // Failed to create one or more zero points above: just add input as an
1332  // operand; this will trigger an error when building the op because of the
1333  // missing zero points.
1334  result.addOperands({input});
1335  }
1336  result.addAttribute("kernel", kernel);
1337  result.addAttribute("stride", stride);
1338  result.addAttribute("pad", pad);
1339  result.addAttribute("acc_type", accType);
1340  result.types.push_back(outputType);
1341 }
1342 
1343 /// This builder is called on the single-parameter negate operator to
1344 /// construct input and output zero points based on their
1345 /// types.
1346 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1347  OperationState &result, Type outputType,
1348  Value input) {
1349  const Location loc{result.location};
1350  int64_t input1Zp{0};
1351  int64_t outputZp{0};
1352  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1353  if (quantAttr) {
1354  input1Zp = quantAttr.getInputZp();
1355  outputZp = quantAttr.getOutputZp();
1356  }
1357  const std::optional<Value> input1ZpOp =
1358  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1359  if (!input1ZpOp) {
1360  (void)emitError(
1361  loc, "Failed to create input1 zero point for quantized NEGATE op");
1362  }
1363 
1364  const std::optional<Value> outputZpOp =
1365  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1366  if (!outputZpOp) {
1367  (void)emitError(
1368  loc, "Failed to create output zero point for quantized NEGATE op");
1369  }
1370 
1371  if (input1ZpOp && outputZpOp) {
1372  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1373  } else {
1374  // Failed to create one or more zero points above: just add input as an
1375  // operand. This will trigger an error when building the op because of the
1376  // missing zero points.
1377  result.addOperands({input});
1378  }
1379 
1380  result.types.push_back(outputType);
1381 }
1382 
1383 /// This builder is called on the TOSA pad operator, which needs to create its own
1384 /// OptionalAttr quantization_attr parameter to scale the padding values
1385 /// correctly. No pad_const is interpreted as zero-padding.
1386 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1387  Type outputType, Value input,
1388  Value paddings) {
1389  const Location loc{result.location};
1390  int32_t zp{0};
1391  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1392  if (quantAttr) {
1393  zp = static_cast<int32_t>(quantAttr.getInputZp());
1394  }
1395  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1396  result.addOperands({input, paddings, padConstOp});
1397  result.types.push_back(outputType);
1398 }
1399 
1400 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1401  StringRef name, Type variableType,
1402  Attribute initialValue) {
1403  const Location loc{result.location};
1404  auto nameAttr = builder.getStringAttr(name);
1405 
1406  auto shapedType = dyn_cast<ShapedType>(variableType);
1407  if (!shapedType) {
1408  (void)emitError(loc, "variable type must be a shaped type");
1409  return;
1410  }
1411  if (!shapedType.hasRank()) {
1412  (void)emitError(loc, "variable type must be a ranked type");
1413  return;
1414  }
1415 
1416  auto elementType = shapedType.getElementType();
1417  auto elementTypeAttr = TypeAttr::get(elementType);
1418  ArrayRef<int64_t> shape = shapedType.getShape();
1419  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1420 
1421  result.addAttribute("name", nameAttr);
1422  result.addAttribute("var_shape", varShapeAttr);
1423  result.addAttribute("type", elementTypeAttr);
1424  result.addAttribute("initial_value", initialValue);
1425 }
1426 
1427 //===----------------------------------------------------------------------===//
1428 // TOSA Operator Return Type Inference.
1429 //===----------------------------------------------------------------------===//
1430 
1431 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1432  SmallVector<int64_t> &outShape) {
1433  int64_t outRank = 0;
1434  for (int i = 0, e = operands.size(); i != e; ++i) {
1435  auto shape = operands.getShape(i);
1436  if (!shape.hasRank()) {
1437  // TODO(jennik): Update function to have better case handling for
1438  // invalid operands and for ranked tensors.
1439  return failure();
1440  }
1441  outRank = std::max<int64_t>(outRank, shape.getRank());
1442  }
1443 
1444  outShape.resize(outRank, 1);
1445 
1446  for (int i = 0, e = operands.size(); i != e; ++i) {
1447  auto shape = operands.getShape(i);
1448  auto rankDiff = outShape.size() - shape.getRank();
1449 
1450  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1451  auto dim1 = outShape[i + rankDiff];
1452  auto dim2 = shape.getDimSize(i);
1453  auto resolvedDim = dim1;
1454 
1455  if (dim1 == 1) {
1456  resolvedDim = dim2;
1457  } else if (dim2 == 1) {
1458  resolvedDim = dim1;
1459  } else if (dim1 != dim2) {
1460  return failure();
1461  }
1462  outShape[i + rankDiff] = resolvedDim;
1463  }
1464  }
1465 
1466  return success();
1467 }
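// Illustrative example (added, not from the source): broadcasting shapes
// [1, 3] and [2, 1] resolves to [2, 3], while [2, 3] and [4, 3] fail
// because 2 and 4 are incompatible non-unit dimensions.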
1468 
1469 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1470  MLIRContext *context, ::std::optional<Location> location,
1471  ArgMaxOp::Adaptor adaptor,
1472  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1473  ShapeAdaptor inputShape(adaptor.getInput().getType());
1474  IntegerAttr axis = adaptor.getProperties().axis;
1475  int32_t axisVal = axis.getValue().getSExtValue();
1476 
1477  if (!inputShape.hasRank()) {
1478  inferredReturnShapes.push_back(ShapedTypeComponents());
1479  return success();
1480  }
1481 
1482  SmallVector<int64_t> outShape;
1483  outShape.reserve(inputShape.getRank() - 1);
1484  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1485  if (i == axisVal)
1486  continue;
1487  outShape.push_back(inputShape.getDimSize(i));
1488  }
1489 
1490  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1491  return success();
1492 }
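
// Worked example (illustrative only): an input of shape [2, 4, 8] with
// axis = 1 drops the axis dimension, giving an inferred output shape of
// [2, 8]; an unranked input yields rank-free output components.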
1493 
1494 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1495  MLIRContext *context, ::std::optional<Location> location,
1496  RFFT2dOp::Adaptor adaptor,
1497  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1498  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1499 
1500  if (!inputShape.hasRank())
1501  return failure();
1502 
1503  llvm::SmallVector<int64_t> outputShape;
1504  outputShape.resize(3, ShapedType::kDynamic);
1505  outputShape[0] = inputShape.getDimSize(0);
1506  outputShape[1] = inputShape.getDimSize(1);
1507  int64_t inWidth = inputShape.getDimSize(2);
1508 
1509  // Note that we can support this calculation symbolically
1510  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1511  if (inWidth != ShapedType::kDynamic)
1512  outputShape[2] = inWidth / 2 + 1;
1513 
1514  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1515  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1516 
1517  return success();
1518 }
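
// Worked example (illustrative only): an input of shape [8, 16, 32] gives
// outputs of shape [8, 16, 32 / 2 + 1] = [8, 16, 17] for both the real and
// imaginary results; a dynamic input width leaves the output width dynamic.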
1519 
1520 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1521  const llvm::StringRef dimName) {
1522  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1523  if (!isPowerOfTwo)
1524  return op->emitOpError("expected ")
1525  << dimName << " to be a power of two, got " << dimSize;
1526 
1527  return success();
1528 }
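
// Worked example (illustrative only): 16 & 15 == 0b10000 & 0b01111 == 0, so
// 16 passes; 12 & 11 == 0b1100 & 0b1011 == 0b1000 != 0, so 12 is rejected.
// The dimSize > 0 term rejects 0, which would otherwise satisfy the bit test.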
1529 
1530 LogicalResult tosa::RFFT2dOp::verify() {
1531  const auto outputTypes = getResultTypes();
1532  if (failed(verifyCompatibleShapes(outputTypes)))
1533  return emitOpError("expected output shapes to match, got ") << outputTypes;
1534 
1535  const auto inputType =
1536  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1537  if (!inputType)
1538  return success();
1539 
1540  const int64_t height = inputType.getDimSize(1);
1541  if (ShapedType::isStatic(height) &&
1542  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1543  return failure();
1544 
1545  const int64_t width = inputType.getDimSize(2);
1546  if (ShapedType::isStatic(width) &&
1547  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1548  return failure();
1549 
1550  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1551  if (!outputType)
1552  return success();
1553 
1554  // Batch and height input/output dimensions should match
1555  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1556  outputType.getShape().drop_back())))
1557  return emitOpError("expected batch and height dimensions of input/output "
1558  "to match, got input=")
1559  << inputType << " output=" << outputType;
1560 
1561  // Output width dimension expected to be input_width / 2 + 1
1562  const int64_t outputWidth = outputType.getDimSize(2);
1563  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1564  (outputWidth != (width / 2) + 1))
1565  return emitOpError(
1566  "expected output width to be equal to input_width / 2 + 1, got ")
1567  << outputWidth;
1568 
1569  return success();
1570 }
1571 
1572 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1573  MLIRContext *context, ::std::optional<Location> location,
1574  FFT2dOp::Adaptor adaptor,
1575  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1576  inferredReturnShapes.push_back(
1577  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1578  inferredReturnShapes.push_back(
1579  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1580  return success();
1581 }
1582 
1583 LogicalResult tosa::FFT2dOp::verify() {
1584  const auto inputRealType =
1585  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1586  const auto inputImagType =
1587  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1588  if (!inputRealType || !inputImagType)
1589  return success();
1590 
1591  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1592  return ShapedType::isDynamic(a) ? a : b;
1593  };
1594 
1595  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1596  inputImagType.getDimSize(1));
1597  if (ShapedType::isStatic(height) &&
1598  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1599  return failure();
1600 
1601  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1602  inputImagType.getDimSize(2));
1603  if (ShapedType::isStatic(width) &&
1604  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1605  return failure();
1606 
1607  return success();
1608 }
1609 
1610 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1611  MLIRContext *context, ::std::optional<Location> location,
1612  ConcatOp::Adaptor adaptor,
1613  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1614  // Infer all dimension sizes by reducing based on inputs.
1615  const Properties &prop = adaptor.getProperties();
1616  int32_t axis = prop.axis.getValue().getSExtValue();
1617  llvm::SmallVector<int64_t> outputShape;
1618  bool hasRankedInput = false;
1619  for (auto operand : adaptor.getOperands()) {
1620  ShapeAdaptor operandShape(operand.getType());
1621  if (!operandShape.hasRank())
1622  continue;
1623 
1624  // Copy the Operand's rank.
1625  if (!hasRankedInput)
1626  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1627 
 1628  // Copy static sizes for non-axis dims; conflicting static sizes are an error.
1629  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1630  if (i == axis || operandShape.isDynamicDim(i))
1631  continue;
1632  if (outputShape[i] == ShapedType::kDynamic)
1633  outputShape[i] = operandShape.getDimSize(i);
1634  if (outputShape[i] != operandShape.getDimSize(i))
1635  return emitOptionalError(location,
1636  "Cannot concat tensors with different sizes"
1637  " on the non-axis dimension ",
1638  i);
1639  }
1640 
1641  hasRankedInput = true;
1642  }
1643 
1644  if (adaptor.getInput1().empty())
1645  return failure();
1646 
1647  Type inputType =
1648  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1649  if (!hasRankedInput) {
1650  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1651  return success();
1652  }
1653 
1654  // Determine the dimension size along the concatenation axis.
1655  int64_t concatDimSize = 0;
1656  for (auto operand : adaptor.getOperands()) {
1657  ShapeAdaptor operandShape(operand.getType());
1658 
1659  // We need to know the length of the concatenation axis of all inputs to
1660  // determine the dimension size of the output shape.
1661  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1662  concatDimSize = ShapedType::kDynamic;
1663  break;
1664  }
1665 
1666  concatDimSize += operandShape.getDimSize(axis);
1667  }
1668 
1669  outputShape[axis] = concatDimSize;
1670 
1671  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1672  return success();
1673 }
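
// Worked example (illustrative only): concatenating tensor<2x3xf32> and
// tensor<2x5xf32> along axis = 1 infers an output shape of [2, 3 + 5] =
// [2, 8]; if any operand's axis dimension is dynamic, the output axis
// dimension is inferred as dynamic.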
1674 
1675 LogicalResult tosa::ConcatOp::verify() {
 1676  // Check that each input has the same element type as the output.
1677  auto outType = getOutput().getType();
1678  const Operation::operand_range inputList = getInput1();
1679 
1680  // Check there is at least one input
1681  if (inputList.empty())
1682  return emitOpError("expect at least one input");
1683 
1684  if (!llvm::all_of(inputList, [&](auto input) {
1685  return succeeded(verifySameElementTypes(
1686  *this, /* inType = */ input.getType(), outType));
1687  })) {
1688  return failure();
1689  }
1690 
1691  const int32_t axis = getAxis();
1692  ShapeAdaptor firstRankedInputShape = nullptr;
1693  for (const auto &input : inputList) {
1694  const Type inputType = input.getType();
1695  ShapeAdaptor currShape(inputType);
1696  if (currShape.hasRank()) {
1697  firstRankedInputShape = currShape;
1698  // Check axis is in expected range
1699  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1700  return emitOpError("expect axis to be within range 0 < axis < "
1701  "rank(input1[firstRankedTensorIdx]), got ")
1702  << axis;
1703  break;
1704  }
1705  }
1706 
1707  const auto allOperandsHasRank = [](const Value input) {
1708  return ShapeAdaptor(input.getType()).hasRank();
1709  };
1710  if (llvm::all_of(inputList, allOperandsHasRank)) {
1711  const int64_t firstInputRank = firstRankedInputShape.getRank();
1712 
1713  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1714  const ShapeAdaptor inputShape(input.getType());
1715  const int64_t inputRank = inputShape.getRank();
1716  const size_t operandNum = index + 1;
1717 
1718  // Check that each operand has the same rank
1719  if (inputRank != firstInputRank)
1720  return emitOpError(
1721  "expect all operands to have the same rank, but got ")
1722  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1723  << operandNum;
1724 
1725  // Check non-axis dims match
1726  for (int i = 0; i < inputRank; i++) {
1727  const int64_t inputDim = inputShape.getDimSize(i);
1728  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1729  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1730  inputShape.isDynamicDim(i))
1731  continue;
1732  if (inputDim != firstInputDim)
1733  return emitOpError("expect all operand shapes to have the same sizes "
1734  "on non-axis dimensions, but got ")
1735  << inputDim << " vs " << firstInputDim << " at index " << i
1736  << " on operands 0 and " << operandNum;
1737  }
1738  }
1739 
1740  // ERROR_IF(axis_sum != shape[axis]);
1741  int64_t axisSum = 0;
1742  for (const auto &input : inputList) {
1743  const ShapeAdaptor inputShape(input.getType());
1744  if (inputShape.isDynamicDim(axis)) {
1745  // make axisSum negative to indicate invalid value
1746  axisSum = -1;
1747  break;
1748  }
1749  axisSum += inputShape.getDimSize(axis);
1750  }
1751  const ShapeAdaptor outputShape(outType);
1752  if (axisSum >= 0 && outputShape.hasRank() &&
1753  !outputShape.isDynamicDim(axis) &&
1754  axisSum != outputShape.getDimSize(axis))
1755  return emitOpError("requires sum of axis dimensions of input1 "
1756  "equal to output axis dimension, got ")
1757  << axisSum << " and " << outputShape.getDimSize(axis);
1758  }
1759 
1760  return success();
1761 }
1762 
1763 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1764  MLIRContext *context, ::std::optional<Location> location,
1765  ValueShapeRange operands, DictionaryAttr attributes,
1766  OpaqueProperties properties, RegionRange regions,
1767  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1768  auto elementType = IntegerType::get(context, /*width=*/1);
1769 
1770  llvm::SmallVector<int64_t> outShape;
1771  if (resolveBroadcastShape(operands, outShape).failed()) {
1772  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1773  return success();
1774  }
1775 
1776  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1777  return success();
1778 }
1779 
1780 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1781  if (l.size() != r.size() || l.size() != 1)
1782  return false;
1783  return succeeded(verifyCompatibleShape(l[0], r[0]));
1784 }
1785 
1786 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1787  MLIRContext *context, ::std::optional<Location> location,
1788  MatMulOp::Adaptor adaptor,
1789  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1790  ShapeAdaptor lhsShape(adaptor.getA().getType());
1791  ShapeAdaptor rhsShape(adaptor.getB().getType());
1792 
1793  // All shapes are dynamic.
1794  SmallVector<int64_t> outShape;
1795  outShape.resize(3, ShapedType::kDynamic);
1796 
1797  if (lhsShape.hasRank()) {
1798  outShape[0] = lhsShape.getDimSize(0);
1799  outShape[1] = lhsShape.getDimSize(1);
1800  }
1801 
1802  if (rhsShape.hasRank()) {
1803  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1804  : outShape[0];
1805  outShape[2] = rhsShape.getDimSize(2);
1806  }
1807 
1808  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1809  return success();
1810 }
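
// Worked example (illustrative only): with A of shape [1, 4, 8] and B of
// shape [1, 8, 16], the inferred output shape is [1, 4, 16]; dims 0 and 1
// come from A (falling back to B's dim 0 when A's is dynamic) and dim 2
// comes from B.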
1811 
1812 LogicalResult MatMulOp::verify() {
1813  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1814  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1815 
1816  // Must be shaped tensor types
1817  if (!aType)
1818  return emitOpError("expect a shaped tensor for input a, got ")
1819  << getA().getType();
1820 
1821  if (!bType)
1822  return emitOpError("expect a shaped tensor for input b, got ")
1823  << getB().getType();
1824 
1825  auto aElementType = aType.getElementType();
1826  auto bElementType = bType.getElementType();
1827 
1828  auto aQuantizedEType =
1829  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1830  auto bQuantizedEType =
1831  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1832 
1833  if (aQuantizedEType || bQuantizedEType) {
1834  if (!aQuantizedEType || !bQuantizedEType) {
1835  return emitOpError("expect operands to be both quantized or both not "
1836  "quantized, got ")
1837  << aElementType << " and " << bElementType;
1838  }
1839  // both a and b have quantized element types
1840  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1841  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1842  if (aQuantWidth != bQuantWidth) {
1843  return emitOpError("expect quantized operands to have same widths, got ")
1844  << aQuantWidth << " and " << bQuantWidth;
1845  }
1846  } else {
1847  // non-quantized element types
1848  if (aElementType != bElementType) {
1849  return emitOpError("expect same element type for inputs a and b, got ")
1850  << aElementType << " and " << bElementType;
1851  }
1852  }
1853 
1854  // check a_zp and b_zp
1855  auto aEType = getStorageElementTypeOrSelf(aType);
1856  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1857  if (aEType != aZpEType) {
1858  return emitOpError("expect input a and a_zp have the same "
1859  "element type, got ")
1860  << aEType << " and " << aZpEType;
1861  }
1862 
1863  auto bEType = getStorageElementTypeOrSelf(bType);
1864  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1865  if (bEType != bZpEType) {
1866  return emitOpError("expect input b and b_zp have the same "
1867  "element type, got ")
1868  << bEType << " and " << bZpEType;
1869  }
1870 
1871  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1872  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1873  return failure();
1874 
1875  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1876  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1877  return failure();
1878 
1879  return success();
1880 }
1881 
1882 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1883  MLIRContext *context, ::std::optional<Location> location,
1884  PadOp::Adaptor adaptor,
1885  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1886  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1887  auto paddingRank =
1888  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1889  SmallVector<int64_t> outputShape;
1890 
1891  // If the input rank is unknown, we can infer the output rank using the
1892  // padding shape's rank divided by 2.
1893  if (!inputShape.hasRank()) {
1894  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1895  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1896  return success();
1897  }
1898 
1899  SmallVector<int64_t> paddingValues;
 1900  // If the padding values are not constant, all dimensions must be dynamic.
1901  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1902  paddingValues)) {
1903  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1904  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1905  return success();
1906  }
1907 
1908  outputShape.reserve(inputShape.getRank());
1909  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1910  if (inputShape.isDynamicDim(i)) {
1911  outputShape.push_back(ShapedType::kDynamic);
1912  continue;
1913  }
1914  auto padFront = paddingValues[i * 2];
1915  auto padBack = paddingValues[i * 2 + 1];
1916  if (padFront < 0 || padBack < 0) {
1917  // if either padding for dim i is -1, output dim is unknown
1918  outputShape.push_back(ShapedType::kDynamic);
1919  continue;
1920  }
1921 
1922  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1923  }
1924 
1925  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1926  return success();
1927 }
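
// Worked example (illustrative only): an input of shape [3, 4] with constant
// padding values [1, 1, 2, 2] (front/back per dimension) infers an output
// shape of [3 + 1 + 1, 4 + 2 + 2] = [5, 8]; a negative (dynamic) padding
// entry makes that output dimension dynamic.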
1928 
1929 LogicalResult tosa::PadOp::verify() {
1930  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1931  /* outType = */ getOutput().getType())
1932  .failed()) {
1933  return failure();
1934  }
1935 
1936  if (auto padConst = getPadConst()) {
1937  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1938  /* outType = */ getOutput().getType())
1939  .failed()) {
1940  return failure();
1941  }
1942  }
1943 
1944  RankedTensorType inputType =
1945  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1946  RankedTensorType outputType =
1947  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1948  if (!inputType || !outputType)
1949  return success();
1950 
1951  auto inputRank = inputType.getRank();
1952  auto outputRank = outputType.getRank();
1953  if (inputRank != outputRank)
1954  return emitOpError() << "expect same input and output tensor rank, but got "
1955  << "inputRank: " << inputRank
1956  << ", outputRank: " << outputRank;
1957 
1958  DenseIntElementsAttr paddingAttr;
1959  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1960  return failure();
1961  }
1962 
1963  auto paddingValues = paddingAttr.getValues<APInt>();
1964  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1965  return emitOpError() << "padding tensor must have " << inputRank
1966  << " * 2 = " << inputRank * 2 << " elements, but got "
1967  << paddingValues.size();
1968 
1969  auto inputShape = inputType.getShape();
1970  auto outputShape = outputType.getShape();
1971 
1972  for (int64_t i = 0; i < inputRank; ++i) {
1973  int64_t padStart = paddingValues[i * 2].getSExtValue();
1974  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1975 
1976  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1977  return emitOpError()
1978  << "invalid padding values at dimension " << i
1979  << ": values must be non-negative or -1 for dynamic padding, got ["
1980  << padStart << ", " << padEnd << "]";
1981  }
1982 
1983  // Skip shape verification for dynamic input/output
1984  if (inputShape[i] == ShapedType::kDynamic ||
1985  outputShape[i] == ShapedType::kDynamic)
1986  continue;
1987 
1988  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1989  return emitOpError() << "mismatch in output shape at dimension " << i
1990  << ": expected " << inputShape[i] << " + "
1991  << padStart << " + " << padEnd << " = "
1992  << (inputShape[i] + padStart + padEnd)
1993  << ", but got " << outputShape[i];
1994  }
1995  }
1996 
1997  return success();
1998 }
1999 
2000 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2001  MLIRContext *context, ::std::optional<Location> location,
2002  SliceOp::Adaptor adaptor,
2003  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2004 
2005  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2006  SmallVector<int64_t> start;
2007  SmallVector<int64_t> size;
2008 
2009  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2010  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2011  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2012  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2013  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2014  return success();
2015  }
2016 
2017  // if size[i] is -1, all remaining elements in dimension i are included
2018  // in the slice, similar to TF.
2019  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2020  // initialize outputShape to all unknown
2021  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2022  if (inputShape.hasRank()) {
2023  for (size_t i = 0; i < size.size(); i++) {
2024  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2025  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2026  start[i] < inputShape.getDimSize(i))) {
2027  // size[i] is not 0 and not < -1, and start[i] is in valid range
2028  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2029  // input shape has unknown dim[i] - only valid if size[i] > 0
2030  if (size[i] > 0) {
2031  outputShape[i] = size[i];
2032  }
2033  } else {
2034  // input shape has known dim[i]
2035  if (size[i] == -1) {
2036  outputShape[i] = inputShape.getDimSize(i) - start[i];
2037  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2038  // start[i] + size[i] is within bound of input shape's dim[i]
2039  outputShape[i] = size[i];
2040  }
2041  }
2042  }
2043  }
2044  } else {
2045  outputShape = convertToMlirShape(size);
2046  }
2047  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2048  return success();
2049 }
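
// Worked example (illustrative only): an input of shape [4, 8] with
// start = [1, 2] and size = [2, -1] infers dim 0 as 2 (since 1 + 2 <= 4)
// and dim 1 as 8 - 2 = 6, i.e. an output shape of [2, 6]; out-of-range
// starts or sizes leave the dimension dynamic.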
2050 
2051 LogicalResult tosa::SliceOp::verify() {
2052  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2053  /* outType = */ getOutput().getType())
2054  .failed())
2055  return failure();
2056 
2057  const ShapeAdaptor inputShape(getInput1().getType());
2058  if (inputShape.hasRank()) {
2059  const auto inputRank = inputShape.getRank();
2060  const ShapeAdaptor outputShape(getOutput().getType());
2061  if (outputShape.hasRank() && inputRank != outputShape.getRank())
2062  return emitOpError(
2063  "expect input1 and output to have the same ranks, got ")
2064  << inputRank << " and " << outputShape.getRank();
2065 
2066  const auto startShapeRank =
2067  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2068  if (inputRank != startShapeRank)
2069  return emitOpError("length of start is not equal to rank of input shape");
2070 
2071  const auto sizeShapeRank =
2072  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2073  if (inputRank != sizeShapeRank)
2074  return emitOpError("length of size is not equal to rank of input shape");
2075  }
2076 
2077  return success();
2078 }
2079 
2080 LogicalResult tosa::MulOp::inferReturnTypeComponents(
2081  MLIRContext *context, ::std::optional<Location> location,
2082  ValueShapeRange operands, DictionaryAttr attributes,
2083  OpaqueProperties properties, RegionRange regions,
2084  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 2085  // mul op's output shape depends only on input1 and input2, not on shift
2086  ValueShapeRange twoInputs = operands.drop_back();
2087  llvm::SmallVector<int64_t> outShape;
2088  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2089  inferredReturnShapes.push_back(ShapedTypeComponents());
2090  } else {
2091  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2092  }
2093  return success();
2094 }
2095 
2096 LogicalResult tosa::MulOp::verify() {
2097  const Value output = getOutput();
2098  auto resElemType = getElementTypeOrSelf(output);
2099 
2100  // Verify if the element type among operands and result match tosa
2101  // specification.
2102  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2103  IntegerType lhsIntType =
2104  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2105  IntegerType rhsIntType =
2106  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2107  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2108  return emitOpError("requires the same element type for all operands");
2109 
 2110  // Though the spec requires the result element type to be i32, a more
 2111  // relaxed rule is applied at the dialect level to ease interoperation
 2112  // with other dialects.
2113  if (lhsIntType.getWidth() > resIntType.getWidth())
2114  return emitOpError("invalid data type size for operands or result");
2115 
2116  } else {
 2117  // For other supported types, the spec requires the same element
 2118  // type for all operands (excluding the `shift` operand) and results.
2119  for (int i = 0; i < 2; ++i) {
2120  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2121  return emitOpError(
2122  "requires the same element type for all operands and results");
2123  }
2124 
2125  // verify shift has value 0 for non-integer types
2126  ElementsAttr shift_elem;
2127  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2128  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2129  if (shift != 0) {
2130  return emitOpError() << "require shift to be 0 for float type";
2131  }
2132  }
2133  }
2134 
 2135  // Verify the op has the same rank for all main operands (this excludes extra
 2136  // operands such as mul's shift, which is the only difference from the built-in
 2137  // `SameOperandsAndResultRank` trait) and result types, when known.
2138  TypeRange operandTypes = getOperandTypes();
2139  ShapedType aType = cast<ShapedType>(operandTypes[0]);
2140  ShapedType bType = cast<ShapedType>(operandTypes[1]);
2141 
2142  const bool aHasRank = aType.hasRank();
2143  const bool bHasRank = bType.hasRank();
2144  if (aHasRank && bHasRank) {
2145  const int64_t aRank = aType.getRank();
2146  const int64_t bRank = bType.getRank();
2147  if (aRank != bRank)
2148  return emitOpError("a and b operands don't have matching ranks, got ")
2149  << aRank << " and " << bRank;
2150 
2151  // check for broadcast compatible shapes
2152  SmallVector<int64_t> resultShape;
 2153  if (!mlir::OpTrait::util::getBroadcastedShape(
 2154  aType.getShape(), bType.getShape(), resultShape))
2155  return emitOpError("a and b operands don't have broadcast-compatible "
2156  "shapes, got ")
2157  << aType << " and " << bType;
2158  }
2159 
2160  ShapedType resultType = cast<ShapedType>(output.getType());
2161  if (!resultType.hasRank())
2162  return success();
2163 
2164  const int64_t resultRank = resultType.getRank();
2165  if (aHasRank && resultRank != aType.getRank())
2166  return emitOpError("result type has different rank than a, got ")
2167  << resultRank << " vs " << aType.getRank();
2168  if (bHasRank && resultRank != bType.getRank())
2169  return emitOpError("result type has different rank than b, got ")
2170  << resultRank << " vs " << bType.getRank();
2171 
2172  return success();
2173 }
2174 
2175 LogicalResult tosa::TableOp::inferReturnTypeComponents(
2176  MLIRContext *context, ::std::optional<Location> location,
2177  TableOp::Adaptor adaptor,
2178  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2179  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2180 
2181  if (!inputShape.hasRank()) {
2182  inferredReturnShapes.push_back(ShapedTypeComponents());
2183  return success();
2184  }
2185 
2186  inferredReturnShapes.resize(1);
2187  inputShape.getDims(inferredReturnShapes[0]);
2188  return success();
2189 }
2190 
2191 LogicalResult tosa::TableOp::verify() {
2192  const TensorType inputType = getInput1().getType();
2193  const TensorType outputType = getOutput().getType();
2194 
2195  if (!inputType.hasRank() || !outputType.hasRank())
2196  return success();
2197 
2198  if (inputType.getRank() != outputType.getRank())
2199  return emitOpError()
2200  << "expected input tensor rank to equal result tensor rank";
2201 
2202  auto inputDims = inputType.getShape();
2203  auto outputDims = outputType.getShape();
2204  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2205  int64_t dim = it.index();
2206  auto [inputDim, outputDim] = it.value();
2207  if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2208  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2209  << " doesn't match dim(input, " << dim
2210  << ") = " << inputDim;
2211  }
2212  }
2213  return success();
2214 }
2215 
2216 LogicalResult
2217 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2218  // Multiples must be constants.
2219  DenseIntElementsAttr multiplesAttr;
2220  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2221  return failure();
2222  multiples = llvm::to_vector(
2223  llvm::map_range(multiplesAttr.getValues<APInt>(),
2224  [](const APInt &val) { return val.getSExtValue(); }));
2225  return success();
2226 }
2227 
2228 LogicalResult tosa::TileOp::inferReturnTypeComponents(
2229  MLIRContext *context, ::std::optional<Location> location,
2230  TileOp::Adaptor adaptor,
2231  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2232  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2233  SmallVector<int64_t> multiples;
2234  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2235  multiples)) {
2236  auto rank =
2237  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2238  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2239  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2240  return success();
2241  } else {
2242  multiples = convertToMlirShape(multiples);
2243  }
2244 
2245  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2246  SmallVector<int64_t> outputShape;
2247  if (!inputShape.hasRank()) {
2248  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2249  inferredReturnShapes.push_back(
2250  ShapedTypeComponents(outputShape, inputType));
2251  return success();
2252  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2253  return failure();
2254 
 2255  // Any non-dynamic dimension can be multiplied to a known size.
2256  outputShape.reserve(multiples.size());
2257  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2258  if (multiples[i] == ShapedType::kDynamic) {
2259  outputShape.push_back(ShapedType::kDynamic);
2260  } else {
2261  int64_t dim = inputShape.getDimSize(i);
2262  if (dim != ShapedType::kDynamic)
2263  dim *= multiples[i];
2264  outputShape.push_back(dim);
2265  }
2266  }
2267 
2268  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2269  return success();
2270 }
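
// Worked example (illustrative only): an input of shape [2, 3] with constant
// multiples [3, 2] infers an output shape of [2 * 3, 3 * 2] = [6, 6]; a
// multiple of -1 (converted to a dynamic value) or a dynamic input dimension
// makes the corresponding output dimension dynamic.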
2271 
2272 LogicalResult tosa::TileOp::verify() {
2273  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2274  /* outType = */ getOutput().getType())
2275  .failed()) {
2276  return failure();
2277  }
2278  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2279  ShapedType outputType = llvm::cast<ShapedType>(getType());
2280 
2281  shapeType multiplesType =
2282  llvm::cast<tosa::shapeType>(getMultiples().getType());
2283 
2284  auto multiplesRank = multiplesType.getRank();
2285 
2286  if (inputType.hasRank()) {
2287  if (inputType.getRank() != multiplesRank)
2288  return emitOpError("expect 'multiples' to have rank ")
2289  << inputType.getRank() << " but got " << multiplesRank << ".";
2290  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2291  return emitOpError("expect same input and output tensor rank.");
2292  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2293  return emitOpError("expect 'multiples' array to have length ")
2294  << outputType.getRank() << " but got " << multiplesRank << ".";
2295 
2296  SmallVector<int64_t> multiples;
2297  if (getConstantMultiples(multiples).succeeded() &&
2298  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2299  return emitOpError(
2300  "expect element of 'multiples' to be positive integer or -1.");
2301 
2302  return success();
2303 }
2304 
2305 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2306  if (l.size() != r.size() || l.size() != 1)
2307  return false;
2308  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2309 }
2310 
2311 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2312  MLIRContext *context, ::std::optional<Location> location,
2313  ReshapeOp::Adaptor adaptor,
2314  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2315  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2316  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2317  llvm::SmallVector<int64_t> newShapeValue;
2318  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2319  newShapeValue)) {
2320  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2321  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2322  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2323  return success();
2324  } else {
2325  newShapeValue = convertToMlirShape(newShapeValue);
2326  }
2327 
 2328  // We cannot infer from the total number of elements, so we must take the
 2329  // new shape values as exact.
2330  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2331  inferredReturnShapes.push_back(
2332  ShapedTypeComponents(newShapeValue, inputType));
2333  return success();
2334  }
2335 
2336  // Determine the number of elements covered by the slice of all static
2337  // dimensions. This allows us to infer the length of the remaining dynamic
2338  // dimension.
2339  int64_t numElements = inputShape.getNumElements();
2340  int64_t staticMul = 1;
2341  for (auto val : newShapeValue) {
2342  if (ShapedType::isStatic(val)) {
2343  staticMul *= val;
2344  }
2345  }
2346 
2347  // Determine the length of the dynamic dimension.
2348  for (auto &val : newShapeValue) {
2349  if (ShapedType::isDynamic(val))
2350  val = numElements / staticMul;
2351  }
2352 
2353  inferredReturnShapes.push_back(
2354  ShapedTypeComponents(newShapeValue, inputType));
2355  return success();
2356 }
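
// Worked example (illustrative only): a static input tensor<2x3x4xf32> holds
// 24 elements, so a new shape of [6, -1] resolves the dynamic dimension to
// 24 / 6 = 4 and infers [6, 4]; with an unranked or non-static input the new
// shape values are taken as-is.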
2357 
2358 llvm::LogicalResult tosa::ReshapeOp::verify() {
2359  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2360  /* outType = */ getOutput().getType())
2361  .failed()) {
2362  return failure();
2363  }
2364  TensorType inputType = getInput1().getType();
2365 
2366  SmallVector<int64_t> shapeValues;
2367  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2368  // skip following checks if shape is not constant
2369  return mlir::success();
2370  }
2371 
2372  int missingDims = llvm::count(shapeValues, -1);
2373  if (missingDims > 1)
2374  return emitOpError() << "expected at most one target dimension to be -1";
2375 
2376  const auto outputType = dyn_cast<RankedTensorType>(getType());
2377  if (!outputType)
2378  return success();
2379 
2380  if ((int64_t)shapeValues.size() != outputType.getRank())
2381  return emitOpError() << "new shape does not match result rank";
2382 
2383  for (auto [newShapeDim, outputShapeDim] :
2384  zip(shapeValues, outputType.getShape())) {
2385  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2386  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2387  return emitOpError() << "new shape is inconsistent with result shape";
2388 
2389  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2390  return emitOpError() << "new shape has invalid tensor dimension size "
2391  << newShapeDim;
2392  }
2393 
2394  if (inputType.hasStaticShape()) {
2395  int64_t inputElementsNum = inputType.getNumElements();
2396  if (outputType.hasStaticShape()) {
2397  int64_t outputElementsNum = outputType.getNumElements();
2398  if (inputElementsNum != outputElementsNum) {
2399  return emitOpError() << "cannot reshape " << inputElementsNum
2400  << " elements into " << outputElementsNum;
2401  }
2402  }
2403 
2404  int64_t newShapeElementsNum = std::accumulate(
2405  shapeValues.begin(), shapeValues.end(), 1LL,
2406  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2407  bool isStaticNewShape =
2408  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2409  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2410  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2411  return emitOpError() << "cannot reshape " << inputElementsNum
2412  << " elements into " << newShapeElementsNum;
2413  }
2414  }
2415 
2416  return mlir::success();
2417 }
2418 
2419 // return failure if val is not a constant
 2420 // set zp to -1 if val is a non-zero float or is neither integer nor float
2421 // otherwise set zp to val's constant value
2422 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2423  ElementsAttr zpAttr;
2424  if (!matchPattern(val, m_Constant(&zpAttr))) {
2425  return failure();
2426  }
2427 
2428  Type zpElemType = zpAttr.getElementType();
2429 
2430  if (llvm::isa<FloatType>(zpElemType)) {
2431  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2432  return 0;
2433  }
2434  // return non-zero value to trigger error check
2435  return -1;
2436  }
2437 
2438  if (llvm::isa<IntegerType>(zpElemType)) {
2439  if (signExtend)
2440  return zpAttr.getValues<APInt>()[0].getSExtValue();
2441  else
2442  return zpAttr.getValues<APInt>()[0].getZExtValue();
2443  }
2444 
2445  // return non-zero value to trigger error check
2446  return -1;
2447 }
2448 
2449 template <typename T>
2450 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2451  const std::string &operand) {
2452  Type zpElemType = getElementTypeOrSelf(val);
2453 
2454  if (!zpElemType.isInteger(8) && zp != 0) {
2455  // convert operand to lower case for error message
2456  std::string lower = operand;
2457  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2458  return op.emitOpError()
2459  << lower << " zero point must be zero for non-int8 integer types";
2460  }
2461 
2462  return success();
2463 }
2464 
2465 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2466  const int64_t &zp,
2467  const std::string &operand) {
2468  bool isInputZp = (operand == "Input");
2469 
2470  bool tensorUnsigned =
2471  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2472  StringRef tensorName = isInputZp ? "input" : "output";
2473 
2474  Type zpElemType = getElementTypeOrSelf(zpVal);
2475 
2476  if (zp != 0) {
2477  if (!zpElemType.isInteger(8) &&
2478  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2479  return op.emitOpError()
2480  << "expect " << tensorName << "_zp of 0, got " << zp;
2481  }
2482  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2483  return op.emitOpError() << "expect " << tensorName
2484  << "_zp of 0 or 32768 for unsigned int16 "
2485  << tensorName << ", got " << zp;
2486  }
2487  }
2488 
2489  return success();
2490 }
2491 
2492 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2493  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2494  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2495  } \
2496  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2497  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2498  }
2499 
2500 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2501 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2502 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2503 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2504 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2505 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2506 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2507 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2508 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2509 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2510 ZERO_POINT_HELPER(MatMulOp, A, true)
2511 ZERO_POINT_HELPER(MatMulOp, B, true)
2512 ZERO_POINT_HELPER(NegateOp, Input1, true)
2513 ZERO_POINT_HELPER(NegateOp, Output, true)
2514 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2515 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2516 #undef ZERO_POINT_HELPER
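
// For reference, ZERO_POINT_HELPER(MatMulOp, A, true) expands to
// tosa::MatMulOp::getAZeroPoint(), which forwards getAZp() to getZeroPoint()
// with sign extension, and tosa::MatMulOp::verifyAZeroPoint(int64_t), which
// forwards to verifyZeroPoint() with the operand name "A".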
2517 
2518 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2519  MLIRContext *context, ::std::optional<Location> location,
2520  TransposeOp::Adaptor adaptor,
2521  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2522  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2523 
 2524  // If the input rank is unknown, the output rank is also
 2525  // unknown.
2526  if (!inputShape.hasRank()) {
2527  inferredReturnShapes.push_back(ShapedTypeComponents());
2528  return success();
2529  }
2530 
2531  const auto inputRank = inputShape.getRank();
2532 
2533  // This would imply the number of permutations does not match the rank of
2534  // the input which is illegal.
2535  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2536  return failure();
2537  }
2538 
2539  SmallVector<int64_t> outputShape;
2540  // Rank-0 means no permutations matter.
2541  if (inputRank == 0) {
2542  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2543  return success();
2544  }
2545 
2546  // Check whether the input dimensions are all the same.
2547  bool allTheSame = true;
2548  for (int i = 1, s = inputRank; i < s; i++) {
2549  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2550  allTheSame = false;
2551  break;
2552  }
2553  }
2554 
2555  // If all of the input dimensions are the same we don't care about the
2556  // permutation.
2557  if (allTheSame) {
2558  outputShape.resize(inputRank, inputShape.getDimSize(0));
2559  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2560  return success();
2561  }
2562 
2563  outputShape.resize(inputRank, ShapedType::kDynamic);
2564 
2565  // Constant permutation values must be within the input rank.
2566  if (llvm::any_of(adaptor.getPerms(),
2567  [inputRank](const auto i) { return i >= inputRank; }))
2568  return failure();
2569 
2570  outputShape.reserve(inputRank);
2571  for (int i = 0, s = inputRank; i < s; i++) {
2572  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2573  }
2574 
2575  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2576  return success();
2577 }
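
// Worked example (illustrative only): an input of shape [1, 2, 3] with
// perms = [2, 0, 1] infers output dim i as input dim perms[i], giving
// [3, 1, 2]; if every input dimension is equal the permutation is irrelevant
// and the input shape is returned directly.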
2578 
2579 LogicalResult tosa::TransposeOp::verify() {
2580  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2581  /* outType = */ getOutput().getType())
2582  .failed()) {
2583  return failure();
2584  }
2585 
2586  const ShapeAdaptor inputShape(getInput1().getType());
2587  const ShapeAdaptor outputShape(getOutput().getType());
2588 
2589  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2590 
2591  if (inputShape.hasRank() &&
2592  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2593  return emitOpError() << "expected perms attribute to have size "
2594  << inputShape.getRank()
2595  << " (input rank) but got size "
2596  << constantPerms.size();
2597 
2598  if (inputShape.hasRank() && outputShape.hasRank() &&
2599  inputShape.getRank() != outputShape.getRank())
2600  return emitOpError()
2601  << "expected input tensor rank to equal result tensor rank";
2602 
2603  if (outputShape.hasRank() &&
2604  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2605  return emitOpError() << "expected perms attribute to have size "
2606  << outputShape.getRank()
2607  << " (output rank) but got size "
2608  << constantPerms.size();
2609 
2610  if (!llvm::all_of(constantPerms,
2611  [&constantPerms](int32_t s) {
2612  return s >= 0 &&
2613  static_cast<size_t>(s) < constantPerms.size();
2614  }) ||
2615  !isPermutationVector(llvm::to_vector(llvm::map_range(
2616  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2617  return emitOpError() << "expected valid permutation indices";
2618 
2619  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2620  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2621  inputShape.getNumElements() != outputShape.getNumElements())
2622  return emitOpError() << "expected input1 and output to have same numbers "
2623  "of elements, got "
2624  << inputShape.getNumElements() << " and "
2625  << outputShape.getNumElements();
2626 
2627  // Verify that the types of the input and output tensors are properly
2628  // permuted.
2629  if (inputShape.hasRank() && outputShape.hasRank()) {
2630  for (auto i = 0; i < outputShape.getRank(); i++) {
2631  if (inputShape.isDynamicDim(constantPerms[i]) ||
2632  outputShape.isDynamicDim(i))
2633  continue;
2634 
2635  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2636  return emitOpError()
2637  << "expected output tensor dim " << i << " to match "
2638  << "input dim " << constantPerms[i] << " with value of "
2639  << inputShape.getDimSize(constantPerms[i]);
2640  }
2641  }
2642 
2643  return success();
2644 }
2645 
2646 LogicalResult TransposeOp::reifyResultShapes(
2647  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2648 
2649  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2650 
2651  Value input = getInput1();
2652  auto inputType = cast<TensorType>(input.getType());
2653 
2654  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2655  for (auto dim : transposePerms) {
2656  int32_t dimInInput = transposePerms[dim];
2657  if (inputType.isDynamicDim(dimInInput))
2658  returnedDims[dim] =
2659  tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2660  .getResult();
2661  else
2662  returnedDims[dim] =
2663  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2664  }
2665 
2666  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2667  return success();
2668 }
2669 
2670 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2671  MLIRContext *context, ::std::optional<Location> location,
2672  GatherOp::Adaptor adaptor,
2673  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2674  llvm::SmallVector<int64_t> outputShape;
2675  outputShape.resize(3, ShapedType::kDynamic);
2676 
2677  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2678  if (valuesShape.hasRank()) {
2679  outputShape[0] = valuesShape.getDimSize(0);
2680  outputShape[2] = valuesShape.getDimSize(2);
2681  }
2682 
2683  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2684  if (indicesShape.hasRank()) {
2685  if (outputShape[0] == ShapedType::kDynamic)
2686  outputShape[0] = indicesShape.getDimSize(0);
2687  if (outputShape[1] == ShapedType::kDynamic)
2688  outputShape[1] = indicesShape.getDimSize(1);
2689  }
2690 
2691  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2692  return success();
2693 }
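
// Worked example (illustrative only): values of shape [2, 10, 4] (N, K, C)
// and indices of shape [2, 3] (N, W) infer an output shape of [2, 3, 4]
// (N, W, C); a dynamic N in values falls back to the indices' dim 0.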
2694 
2695 LogicalResult tosa::GatherOp::verify() {
2696  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2697  /* outType = */ getOutput().getType())
2698  .failed()) {
2699  return failure();
2700  }
2701 
2702  const ShapeAdaptor valuesShape(getValues().getType());
2703  const ShapeAdaptor indicesShape(getIndices().getType());
2704  const ShapeAdaptor outputShape(getOutput().getType());
2705 
2706  int64_t N = ShapedType::kDynamic;
2707  int64_t W = ShapedType::kDynamic;
2708  int64_t C = ShapedType::kDynamic;
2709 
2710  if (valuesShape.hasRank()) {
2711  N = valuesShape.getDimSize(0);
2712  C = valuesShape.getDimSize(2);
2713  }
2714  if (indicesShape.hasRank()) {
2715  const int64_t indicesN = indicesShape.getDimSize(0);
2716  W = indicesShape.getDimSize(1);
2717  if (N == ShapedType::kDynamic)
2718  N = indicesN;
2719  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2720  return emitOpError() << "requires indices dimension 0 to have size " << N
2721  << ", got " << indicesN;
2722  }
2723  if (outputShape.hasRank()) {
2724  const int64_t outputN = outputShape.getDimSize(0);
2725  const int64_t outputW = outputShape.getDimSize(1);
2726  const int64_t outputC = outputShape.getDimSize(2);
2727  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2728  N != outputN)
2729  return emitOpError() << "requires output dimension 0 to have size " << N
2730  << ", got " << outputN;
2731 
2732  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2733  W != outputW)
2734  return emitOpError() << "requires output dimension 1 to have size " << W
2735  << ", got " << outputW;
2736  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2737  C != outputC)
2738  return emitOpError() << "requires output dimension 2 to have size " << C
2739  << ", got " << outputC;
2740  }
2741  return success();
2742 }
2743 
2744 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2745  MLIRContext *context, ::std::optional<Location> location,
2746  ResizeOp::Adaptor adaptor,
2747  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2748  llvm::SmallVector<int64_t, 4> outputShape;
2749  outputShape.resize(4, ShapedType::kDynamic);
2750 
2751  ShapeAdaptor inputShape(adaptor.getInput().getType());
2752  if (!inputShape.hasRank())
2753  return failure();
2754 
2755  outputShape[0] = inputShape.getDimSize(0);
2756  outputShape[3] = inputShape.getDimSize(3);
2757  int64_t inputHeight = inputShape.getDimSize(1);
2758  int64_t inputWidth = inputShape.getDimSize(2);
2759 
2760  if ((inputHeight == ShapedType::kDynamic) ||
2761  (inputWidth == ShapedType::kDynamic))
2762  return failure();
2763 
2764  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2765  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2766  scaleInt) ||
2767  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2768  offsetInt) ||
2769  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2770  borderInt)) {
2771  return failure();
2772  }
2773 
2774  // Compute the output shape based on attributes: scale, offset, and border.
2775  const int64_t outputHeight =
2776  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2777  scaleInt[1]) +
2778  1;
2779 
2780  const int64_t outputWidth =
2781  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2782  scaleInt[3]) +
2783  1;
2784 
2785  if (outputHeight < 0 || outputWidth < 0) {
2786  return emitOptionalError(
2787  location,
2788  "calculated output height and width must be non-negative, "
2789  "got height = ",
2790  outputHeight, ", width = ", outputWidth);
2791  }
2792 
2793  outputShape[1] = outputHeight;
2794  outputShape[2] = outputWidth;
2795  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2796  return success();
2797 }
2798 
2799 LogicalResult tosa::ResizeOp::verify() {
2800  const Value input = getInput();
2801  const Value output = getOutput();
2802  const RankedTensorType inputType =
2803  llvm::dyn_cast<RankedTensorType>(input.getType());
2804  const RankedTensorType outputType =
2805  llvm::dyn_cast<RankedTensorType>(output.getType());
2806 
2807  SmallVector<int64_t> scaleValues;
2808  SmallVector<int64_t> offsetValues;
2809  SmallVector<int64_t> borderValues;
2810  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2811  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2812  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2813  // Skip following checks if shape is not constant
2814  return success();
2815  }
2816 
2817  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2818  return emitOpError("expect all scale values to be > 0, got ")
2819  << scaleValues;
2820 
2821  const int64_t scaleYN = scaleValues[0];
2822  const int64_t scaleYD = scaleValues[1];
2823  const int64_t scaleXN = scaleValues[2];
2824  const int64_t scaleXD = scaleValues[3];
2825 
2826  const int64_t offsetY = offsetValues[0];
2827  const int64_t offsetX = offsetValues[1];
2828 
2829  const int64_t borderY = borderValues[0];
2830  const int64_t borderX = borderValues[1];
2831 
2832  if (!inputType)
2833  return success();
2834  if (!outputType)
2835  return success();
2836 
2837  const int64_t oh = outputType.getDimSize(1);
2838  const int64_t ow = outputType.getDimSize(2);
2839  const int64_t ih = inputType.getDimSize(1);
2840  const int64_t iw = inputType.getDimSize(2);
2841 
 2842  // Skip the check when the input height may be broadcast (ih == 1),
2843  // since Linalg, a consumer of TOSA, expects broadcasting support
2844  // in resize to be available. Taking the cautious approach for now,
2845  // we can consider removing support for broadcasting later.
2846  if (ih != ShapedType::kDynamic && ih != 1) {
2847  const std::optional<int64_t> calculatedOutHeightMinusOne =
2848  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2849  if (!calculatedOutHeightMinusOne.has_value())
2850  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2851  "border_y ")
2852  << "to be wholly divisible by scale_y_d, got ((" << ih
2853  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2854  << ") / " << scaleYD;
2855  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2856  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2857  return emitOpError("calculated output height did not match expected: ")
2858  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2859  }
2860 
 2861  // Skip the check when the input width may be broadcast (iw == 1),
2862  // since Linalg, a consumer of TOSA, expects broadcasting support
2863  // in resize to be available. Taking the cautious approach for now,
2864  // we can consider removing support for broadcasting later.
2865  if (iw != ShapedType::kDynamic && iw != 1) {
2866  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2867  const std::optional<int64_t> calculatedOutWidthMinusOne =
2868  idivCheck(scaledInWidth, scaleXD);
2869  if (!calculatedOutWidthMinusOne.has_value())
2870  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2871  "border_x ")
2872  << "to be wholly divisible by scale_x_d, got ((" << iw
2873  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2874  << ") / " << scaleXD;
2875  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2876  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2877  return emitOpError("calculated output width did not match expected: ")
2878  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2879  }
2880 
2881  return success();
2882 }
2883 
2884 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2885  MLIRContext *context, ::std::optional<Location> location,
2886  ScatterOp::Adaptor adaptor,
2887  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2888  llvm::SmallVector<int64_t> outputShape;
2889  outputShape.resize(3, ShapedType::kDynamic);
2890 
2891  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2892  if (valuesInShape.hasRank()) {
2893  outputShape[0] = valuesInShape.getDimSize(0);
2894  outputShape[1] = valuesInShape.getDimSize(1);
2895  outputShape[2] = valuesInShape.getDimSize(2);
2896  }
2897 
2898  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2899  if (indicesShape.hasRank()) {
2900  if (outputShape[0] == ShapedType::kDynamic)
2901  outputShape[0] = indicesShape.getDimSize(0);
2902  }
2903 
2904  ShapeAdaptor inputShape(adaptor.getInput().getType());
2905  if (inputShape.hasRank()) {
2906  if (outputShape[0] == ShapedType::kDynamic)
2907  outputShape[0] = inputShape.getDimSize(0);
2908  if (outputShape[2] == ShapedType::kDynamic)
2909  outputShape[2] = inputShape.getDimSize(2);
2910  }
2911 
2912  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2913  return success();
2914 }
2915 
2916 LogicalResult tosa::ScatterOp::verify() {
2917  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2918  /* outType = */ getValuesOut().getType())
2919  .failed() ||
2920  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2921  /* outType = */ getValuesOut().getType())
2922  .failed()) {
2923  return failure();
2924  }
2925 
2926  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2927  const ShapeAdaptor indicesShape(getIndices().getType());
2928  const ShapeAdaptor inputShape(getInput().getType());
2929  const ShapeAdaptor outputShape(getValuesOut().getType());
2930 
2931  int64_t N = ShapedType::kDynamic;
2932  int64_t K = ShapedType::kDynamic;
2933  int64_t W = ShapedType::kDynamic;
2934  int64_t C = ShapedType::kDynamic;
2935  if (valuesInShape.hasRank()) {
2936  N = valuesInShape.getDimSize(0);
2937  K = valuesInShape.getDimSize(1);
2938  C = valuesInShape.getDimSize(2);
2939  }
2940  if (indicesShape.hasRank()) {
2941  const int64_t indicesN = indicesShape.getDimSize(0);
2942  W = indicesShape.getDimSize(1);
2943  if (N == ShapedType::kDynamic)
2944  N = indicesN;
2945  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2946  return emitOpError() << "requires indices dimension 0 to have size " << N
2947  << ", got " << indicesN;
2948  }
2949  if (inputShape.hasRank()) {
2950  const int64_t inputN = inputShape.getDimSize(0);
2951  const int64_t inputW = inputShape.getDimSize(1);
2952  const int64_t inputC = inputShape.getDimSize(2);
2953  if (N == ShapedType::kDynamic)
2954  N = inputN;
2955  else if (inputN != ShapedType::kDynamic && N != inputN)
2956  return emitOpError() << "requires input dimension 0 to have size " << N
2957  << ", got " << inputN;
2958  if (W == ShapedType::kDynamic)
2959  W = inputW;
2960  else if (inputW != ShapedType::kDynamic && W != inputW)
2961  return emitOpError() << "requires input dimension 1 to have size " << W
2962  << ", got " << inputW;
2963 
2964  if (C == ShapedType::kDynamic)
2965  C = inputC;
2966  else if (inputC != ShapedType::kDynamic && C != inputC)
2967  return emitOpError() << "requires input dimension 2 to have size " << C
2968  << ", got " << inputC;
2969  }
2970  if (outputShape.hasRank()) {
2971  const int64_t outputN = outputShape.getDimSize(0);
2972  const int64_t outputK = outputShape.getDimSize(1);
2973  const int64_t outputC = outputShape.getDimSize(2);
2974  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2975  N != outputN)
2976  return emitOpError() << "requires values_out dimension 0 to have size "
2977  << N << ", got " << outputN;
2978  if (K == ShapedType::kDynamic)
2979  K = outputK;
2980  else if (outputK != ShapedType::kDynamic && K != outputK)
2981  return emitOpError() << "requires values_out dimension 1 to have size "
2982  << K << ", got " << outputK;
2983  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2984  C != outputC)
2985  return emitOpError() << "requires values_out dimension 2 to have size "
2986  << C << ", got " << outputC;
2987  }
2988  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2989  return emitOpError() << "requires dimensions K >= W, got K=" << K
2990  << " and W=" << W;
2991 
2992  return success();
2993 }
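  // Illustrative shape check (example values only): with
  //   values_in  : [N=2, K=8, C=3]
  //   indices    : [N=2, W=5]
  //   input      : [N=2, W=5, C=3]
  //   values_out : [N=2, K=8, C=3]
  // every cross-operand dimension above agrees and K (8) >= W (5), so the
  // verifier succeeds; e.g. K=4 with W=5 would fail the final K >= W check.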
2994 
2995 static LogicalResult ReduceInferReturnTypes(
2996  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2997  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2998  int64_t axisVal = axis.getValue().getSExtValue();
2999  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3000  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3001  return success();
3002  }
3003 
3004  SmallVector<int64_t> outputShape;
3005  operandShape.getDims(outputShape);
3006  outputShape[axisVal] = 1;
3007  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3008  return success();
3009 }
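  // For example, reducing axis 1 of a rank-3 operand with dims [4, 5, 6] infers
  // an output of dims [4, 1, 6]; if the operand is unranked or the axis is out
  // of range, only the element type is propagated (first branch above).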
3010 
3011 #define COMPATIBLE_RETURN_TYPES(OP) \
3012  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3013  if (l.size() != r.size() || l.size() != 1) \
3014  return false; \
3015  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3016  return false; \
3017  return succeeded(verifyCompatibleShape(l[0], r[0])); \
3018  }
3019 
3020 #define REDUCE_SHAPE_INFER(OP) \
3021  LogicalResult OP::inferReturnTypeComponents( \
3022  MLIRContext *context, ::std::optional<Location> location, \
3023  OP::Adaptor adaptor, \
3024  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3025  Type inputType = \
3026  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3027  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3028  const Properties &prop = adaptor.getProperties(); \
3029  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3030  inferredReturnShapes); \
3031  } \
3032  COMPATIBLE_RETURN_TYPES(OP)
3033 
3034 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3035 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3036 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3037 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3038 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3039 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3040 #undef REDUCE_SHAPE_INFER
3041 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3042 #undef COMPATIBLE_RETURN_TYPES
3043 
3044 template <typename T>
3045 static LogicalResult verifyReduceOp(T op) {
3046  // All TOSA reduce Ops have input, output and axis.
3047  TensorType inputType = op.getInput().getType();
3048  TensorType outputType = op.getOutput().getType();
3049  int32_t reduceAxis = op.getAxis();
3050 
3051  if (reduceAxis < 0) {
3052  op.emitOpError("reduce axis must not be negative");
3053  return failure();
3054  }
3055  if (inputType.hasRank()) {
3056  int64_t inputRank = inputType.getRank();
3057  // We allow for a special case where the input/output shape has rank 0 and
3058  // axis is also 0.
3059  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3060  op.emitOpError("expect input tensor rank (")
3061  << inputRank << ") to be larger than reduce axis (" << reduceAxis
3062  << ")";
3063  return failure();
3064  }
3065  }
3066  if (outputType.hasRank()) {
3067  int64_t outputRank = outputType.getRank();
3068  if (inputType.hasRank() && outputRank != inputType.getRank()) {
3069  op.emitOpError(
3070  "expect output tensor rank to be equal to input tensor rank");
3071  return failure();
3072  }
3073  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3074  op.emitOpError("expect output tensor rank (")
3075  << outputRank << ") to be larger than reduce axis (" << reduceAxis
3076  << ")";
3077  return failure();
3078  }
3079  // We can only verify the reduced dimension size to be 1 if this is not
3080  // the special case of output rank == 0.
3081  if (outputRank != 0) {
3082  auto outputShape = outputType.getShape();
3083  if (!outputType.isDynamicDim(reduceAxis) &&
3084  outputShape[reduceAxis] != 1) {
3085  op.emitOpError("expect reduced dimension size to be 1, got ")
3086  << outputShape[reduceAxis];
3087  return failure();
3088  }
3089  }
3090  }
3091  return success();
3092 }
3093 
3094 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3095 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3096 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3097 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3098 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3099 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3100 
3101 static LogicalResult NAryInferReturnTypes(
3102  const ValueShapeRange &operands,
3103  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3104  llvm::SmallVector<int64_t> outShape;
3105  if (resolveBroadcastShape(operands, outShape).failed()) {
3106  inferredReturnShapes.push_back(ShapedTypeComponents());
3107  } else {
3108  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3109  }
3110  return success();
3111 }
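  // For example, operands of shapes [2, 1, 5] and [1, 3, 5] broadcast to an
  // inferred output of shape [2, 3, 5]; if the broadcast shape cannot be
  // resolved (e.g. an operand is unranked), a shape-unknown component is
  // recorded instead.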
3112 
3113 #define NARY_SHAPE_INFER(OP) \
3114  LogicalResult OP::inferReturnTypeComponents( \
3115  MLIRContext *context, ::std::optional<Location> location, \
3116  ValueShapeRange operands, DictionaryAttr attributes, \
3117  OpaqueProperties properties, RegionRange regions, \
3118  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3119  return NAryInferReturnTypes(operands, inferredReturnShapes); \
3120  }
3121 
3122 NARY_SHAPE_INFER(tosa::AbsOp)
3123 NARY_SHAPE_INFER(tosa::AddOp)
3124 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3125 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3126 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3127 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3128 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3129 NARY_SHAPE_INFER(tosa::CastOp)
3130 NARY_SHAPE_INFER(tosa::CeilOp)
3131 NARY_SHAPE_INFER(tosa::ClampOp)
3132 NARY_SHAPE_INFER(tosa::ClzOp)
3133 NARY_SHAPE_INFER(tosa::CosOp)
3134 NARY_SHAPE_INFER(tosa::ExpOp)
3135 NARY_SHAPE_INFER(tosa::FloorOp)
3136 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3137 NARY_SHAPE_INFER(tosa::GreaterOp)
3138 NARY_SHAPE_INFER(tosa::IdentityOp)
3139 NARY_SHAPE_INFER(tosa::IntDivOp)
3140 NARY_SHAPE_INFER(tosa::LogOp)
3141 NARY_SHAPE_INFER(tosa::LogicalAndOp)
3142 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3143 NARY_SHAPE_INFER(tosa::LogicalNotOp)
3144 NARY_SHAPE_INFER(tosa::LogicalOrOp)
3145 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3146 NARY_SHAPE_INFER(tosa::LogicalXorOp)
3147 NARY_SHAPE_INFER(tosa::MaximumOp)
3148 NARY_SHAPE_INFER(tosa::MinimumOp)
3149 NARY_SHAPE_INFER(tosa::PowOp)
3150 NARY_SHAPE_INFER(tosa::ReciprocalOp)
3151 NARY_SHAPE_INFER(tosa::ReverseOp)
3152 NARY_SHAPE_INFER(tosa::RsqrtOp)
3153 NARY_SHAPE_INFER(tosa::SinOp)
3154 NARY_SHAPE_INFER(tosa::SelectOp)
3155 NARY_SHAPE_INFER(tosa::SubOp)
3156 NARY_SHAPE_INFER(tosa::TanhOp)
3157 NARY_SHAPE_INFER(tosa::ErfOp)
3158 NARY_SHAPE_INFER(tosa::SigmoidOp)
3159 #undef NARY_SHAPE_INFER
3160 
3161 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3162  MLIRContext *context, ::std::optional<Location> location,
3163  NegateOp::Adaptor adaptor,
3164  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3165  ShapeAdaptor inputShape(adaptor.getInput1().getType());
3166  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3167  return success();
3168 }
3169 
3170 LogicalResult tosa::NegateOp::verify() {
3171  // Verify same element type
3172  const Type input1Type = getInput1().getType();
3173  const Type outputType = getOutput().getType();
3174  if (verifySameElementTypes(*this, input1Type, outputType).failed())
3175  return failure();
3176 
3177  // Verify same shape
3178  const SmallVector<Type, 2> types = {input1Type, outputType};
3179  if (failed(verifyCompatibleShapes(types)))
3180  return emitOpError() << "requires the same shape for input1 and output";
3181 
3182  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3183  const Type input1ZpEType =
3184  getStorageElementTypeOrSelf(getInput1Zp().getType());
3185  if (input1EType != input1ZpEType) {
3186  return emitOpError("expect both input1 and its zero point are the same "
3187  "element type, got ")
3188  << input1EType << " and " << input1ZpEType;
3189  }
3190  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3191  const Type outputZpEType =
3192  getStorageElementTypeOrSelf(getOutputZp().getType());
3193  if (outputEType != outputZpEType) {
3194  return emitOpError("expect both output and its zero point are the same "
3195  "element type, got ")
3196  << outputEType << " and " << outputZpEType;
3197  }
3198 
3199  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3200  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3201  return failure();
3202 
3203  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3204  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3205  return failure();
3206 
3207  return success();
3208 }
3209 
3210 static LogicalResult poolingInferReturnTypes(
3211  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3212  ArrayRef<int64_t> pad,
3213  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3214  llvm::SmallVector<int64_t> outputShape;
3215  outputShape.resize(4, ShapedType::kDynamic);
3216 
3217  // Without any input shape information we still know the output is rank 4,
3218  // with all dimensions dynamic.
3218  if (!inputShape) {
3219  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3220  return success();
3221  }
3222 
3223  // The batch and channel dimensions are passed through unchanged by pooling.
3224  outputShape[0] = inputShape.getDimSize(0);
3225  outputShape[3] = inputShape.getDimSize(3);
3226 
3227  int64_t height = inputShape.getDimSize(1);
3228  int64_t width = inputShape.getDimSize(2);
3229 
3230  if (ShapedType::isStatic(height)) {
3231  int64_t padded = height + pad[0] + pad[1] - kernel[0];
3232  outputShape[1] = padded / stride[0] + 1;
3233  }
3234 
3235  if (ShapedType::isStatic(width)) {
3236  int64_t padded = width + pad[2] + pad[3] - kernel[1];
3237  outputShape[2] = padded / stride[1] + 1;
3238  }
3239 
3240  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3241  return success();
3242 }
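  // Worked example for the height dimension (illustrative values): a static
  // input height of 32 with pad[0] = pad[1] = 0, kernel[0] = 2 and
  // stride[0] = 2 gives (32 + 0 + 0 - 2) / 2 + 1 = 16 for outputShape[1].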
3243 
3244 LogicalResult Conv2DOp::inferReturnTypeComponents(
3245  MLIRContext *context, ::std::optional<Location> location,
3246  Conv2DOp::Adaptor adaptor,
3247  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3248  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3249 
3250  int64_t inputWidth = ShapedType::kDynamic;
3251  int64_t inputHeight = ShapedType::kDynamic;
3252  int64_t weightWidth = ShapedType::kDynamic;
3253  int64_t weightHeight = ShapedType::kDynamic;
3254 
3255  // Input shape describes input width/height and batch.
3256 
3257  ShapeAdaptor inputShape(adaptor.getInput().getType());
3258  if (inputShape.hasRank()) {
3259  outputShape[0] = inputShape.getDimSize(0);
3260  inputHeight = inputShape.getDimSize(1);
3261  inputWidth = inputShape.getDimSize(2);
3262  }
3263 
3264  // The weight shape describes the filter height/width and the output channels.
3265  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3266  if (weightShape.hasRank()) {
3267  outputShape[3] = weightShape.getDimSize(0);
3268  weightHeight = weightShape.getDimSize(1);
3269  weightWidth = weightShape.getDimSize(2);
3270  }
3271 
3272  // Bias shape can describe the output channels.
3273  ShapeAdaptor biasShape(adaptor.getBias().getType());
3274  if (biasShape.hasRank()) {
3275  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3276  ? biasShape.getDimSize(0)
3277  : outputShape[3];
3278  }
3279 
3280  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3281  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3282  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3283 
3284  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3285  int64_t inputSize = inputHeight + padding[0] + padding[1];
3286  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3287  int64_t unstridedResult = inputSize - filterSize + 1;
3288  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3289  }
3290 
3291  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3292  int64_t inputSize = inputWidth + padding[2] + padding[3];
3293  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3294  int64_t unstridedResult = inputSize - filterSize + 1;
3295  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3296  }
3297 
3298  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3299  return success();
3300 }
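  // Worked example for the height dimension (illustrative values): with
  // inputHeight = 32, padding[0] = padding[1] = 1, weightHeight = 3,
  // dilation[0] = 1 and stride[0] = 1: inputSize = 34, filterSize = 3,
  // unstridedResult = 32, so outputShape[1] = (32 - 1) / 1 + 1 = 32.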
3301 
3302 LogicalResult Conv2DOp::verify() {
3303  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3304  verifyConvOpErrorIf(*this).failed())
3305  return failure();
3306  return success();
3307 }
3308 
3309 LogicalResult Conv3DOp::inferReturnTypeComponents(
3310  MLIRContext *context, ::std::optional<Location> location,
3311  Conv3DOp::Adaptor adaptor,
3312  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3313  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3314 
3315  int64_t inputWidth = ShapedType::kDynamic;
3316  int64_t inputHeight = ShapedType::kDynamic;
3317  int64_t inputDepth = ShapedType::kDynamic;
3318 
3319  int64_t weightWidth = ShapedType::kDynamic;
3320  int64_t weightHeight = ShapedType::kDynamic;
3321  int64_t weightDepth = ShapedType::kDynamic;
3322 
3323  // The input shape describes the batch and the input depth/height/width.
3324  ShapeAdaptor inputShape(adaptor.getInput().getType());
3325  if (inputShape.hasRank()) {
3326  outputShape[0] = inputShape.getDimSize(0);
3327  inputDepth = inputShape.getDimSize(1);
3328  inputHeight = inputShape.getDimSize(2);
3329  inputWidth = inputShape.getDimSize(3);
3330  }
3331 
3332  // The weight shape describes the filter depth/height/width and the output channels.
3333  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3334  if (weightShape.hasRank()) {
3335  outputShape[4] = weightShape.getDimSize(0);
3336  weightDepth = weightShape.getDimSize(1);
3337  weightHeight = weightShape.getDimSize(2);
3338  weightWidth = weightShape.getDimSize(3);
3339  }
3340 
3341  // Bias shape can describe the output channels.
3342  ShapeAdaptor biasShape(adaptor.getBias().getType());
3343  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3344  outputShape[4] = biasShape.getDimSize(0);
3345  }
3346 
3347  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3348  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3349  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3350 
3351  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3352  int64_t inputSize = inputDepth + pad[0] + pad[1];
3353  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3354  int64_t unstridedResult = inputSize - filterSize + 1;
3355  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3356  }
3357 
3358  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3359  int64_t inputSize = inputHeight + pad[2] + pad[3];
3360  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3361  int64_t unstridedResult = inputSize - filterSize + 1;
3362  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3363  }
3364 
3365  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3366  int64_t inputSize = inputWidth + pad[4] + pad[5];
3367  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3368  int64_t unstridedResult = inputSize - filterSize + 1;
3369  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3370  }
3371 
3372  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3373  return success();
3374 }
3375 
3376 LogicalResult Conv3DOp::verify() {
3377  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3378  verifyConvOpErrorIf(*this).failed())
3379  return failure();
3380  return success();
3381 }
3382 
3383 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3384  MLIRContext *context, ::std::optional<Location> location,
3385  AvgPool2dOp::Adaptor adaptor,
3386  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3387  ShapeAdaptor inputShape(adaptor.getInput().getType());
3388  const Properties &prop = adaptor.getProperties();
3389  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3390  inferredReturnShapes);
3391 }
3392 
3393 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3394  MLIRContext *context, ::std::optional<Location> location,
3395  MaxPool2dOp::Adaptor adaptor,
3396  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3397  ShapeAdaptor inputShape(adaptor.getInput().getType());
3398  const Properties &prop = adaptor.getProperties();
3399  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3400  inferredReturnShapes);
3401 }
3402 
3403 LogicalResult MaxPool2dOp::verify() {
3404  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3405  /* outType = */ getOutput().getType())))
3406  return failure();
3407 
3408  if (failed(verifyPoolingOp(*this)))
3409  return failure();
3410 
3411  return success();
3412 }
3413 
3414 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3415  MLIRContext *context, ::std::optional<Location> location,
3416  DepthwiseConv2DOp::Adaptor adaptor,
3417  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3418  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3419 
3420  int64_t inputWidth = ShapedType::kDynamic;
3421  int64_t inputHeight = ShapedType::kDynamic;
3422  int64_t inputChannels = ShapedType::kDynamic;
3423 
3424  int64_t weightWidth = ShapedType::kDynamic;
3425  int64_t weightHeight = ShapedType::kDynamic;
3426  int64_t depthChannels = ShapedType::kDynamic;
3427 
3428  // Input shape describes input width/height and batch.
3429  ShapeAdaptor inputShape(adaptor.getInput().getType());
3430  if (inputShape.hasRank()) {
3431  outputShape[0] = inputShape.getDimSize(0);
3432  inputHeight = inputShape.getDimSize(1);
3433  inputWidth = inputShape.getDimSize(2);
3434  inputChannels = inputShape.getDimSize(3);
3435  }
3436 
3437  // The weight shape describes the filter height/width, the input channels and
3438  // the channel multiplier.
3438  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3439  if (weightShape.hasRank()) {
3440  weightHeight = weightShape.getDimSize(0);
3441  weightWidth = weightShape.getDimSize(1);
3442  inputChannels = ShapedType::isDynamic(inputChannels)
3443  ? weightShape.getDimSize(2)
3444  : inputChannels;
3445  depthChannels = weightShape.getDimSize(3);
3446  }
3447 
3448  // If both inputChannels and depthChannels are available we can determine
3449  // the output channels.
3450  if (ShapedType::isStatic(inputChannels) &&
3451  ShapedType::isStatic(depthChannels)) {
3452  outputShape[3] = inputChannels * depthChannels;
3453  }
3454 
3455  // Bias shape can describe the output channels.
3456  ShapeAdaptor biasShape(adaptor.getBias().getType());
3457  if (biasShape.hasRank()) {
3458  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3459  ? biasShape.getDimSize(0)
3460  : outputShape[3];
3461  }
3462 
3463  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3464  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3465  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3466 
3467  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3468  int64_t inputSize = inputHeight + padding[0] + padding[1];
3469  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3470  int64_t unstridedResult = inputSize - filterSize + 1;
3471  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3472  }
3473 
3474  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3475  int64_t inputSize = inputWidth + padding[2] + padding[3];
3476  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3477  int64_t unstridedResult = inputSize - filterSize + 1;
3478  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3479  }
3480 
3481  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3482  return success();
3483 }
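  // For example (illustrative values), inputChannels = 8 with a channel
  // multiplier depthChannels = 2 gives outputShape[3] = 16; the spatial
  // dimensions follow the same padded/dilated/strided arithmetic as Conv2D.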
3484 
3485 LogicalResult DepthwiseConv2DOp::verify() {
3486  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3487  verifyConvOpErrorIf(*this).failed())
3488  return failure();
3489  return success();
3490 }
3491 
3492 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3493  MLIRContext *context, ::std::optional<Location> location,
3494  TransposeConv2DOp::Adaptor adaptor,
3495  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3496  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3497 
3498  int64_t inputWidth = ShapedType::kDynamic;
3499  int64_t inputHeight = ShapedType::kDynamic;
3500  int64_t weightWidth = ShapedType::kDynamic;
3501  int64_t weightHeight = ShapedType::kDynamic;
3502 
3503  // Input shape describes input width/height and batch.
3504  ShapeAdaptor inputShape(adaptor.getInput().getType());
3505  if (inputShape.hasRank()) {
3506  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3507  ? inputShape.getDimSize(0)
3508  : outputShape[0];
3509  inputHeight = inputShape.getDimSize(1);
3510  inputWidth = inputShape.getDimSize(2);
3511  }
3512 
3513  // The weight shape describes the filter height/width and the output channels.
3514  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3515  if (weightShape.hasRank()) {
3516  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3517  ? weightShape.getDimSize(0)
3518  : outputShape[3];
3519  weightHeight = weightShape.getDimSize(1);
3520  weightWidth = weightShape.getDimSize(2);
3521  }
3522 
3523  // Bias shape can describe the output channels.
3524  ShapeAdaptor biasShape(adaptor.getBias().getType());
3525  if (biasShape.hasRank()) {
3526  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3527  ? biasShape.getDimSize(0)
3528  : outputShape[3];
3529  }
3530 
3531  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3532  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3533 
3534  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3535  int64_t calculateSize =
3536  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3537  outputShape[1] =
3538  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3539  }
3540 
3541  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3542  int64_t calculateSize =
3543  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3544  outputShape[2] =
3545  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3546  }
3547 
3548  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3549  return success();
3550 }
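  // Worked example for the height dimension (illustrative values): with
  // inputHeight = 8, stride[0] = 2, padding[0] = padding[1] = 0 and
  // weightHeight = 3, the inferred output height is (8 - 1) * 2 + 0 + 0 + 3 = 17.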
3551 
3552 LogicalResult TransposeConv2DOp::verify() {
3553  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3554  return failure();
3555 
3556  const llvm::ArrayRef<int64_t> strides = getStride();
3557  const int64_t strideY = strides[0];
3558  const int64_t strideX = strides[1];
3559 
3560  if (strideY < 1 || strideX < 1)
3561  return emitOpError("expect all stride values to be >= 1, got [")
3562  << strides << "]";
3563 
3564  const auto checkPadAgainstKernelDim =
3565  [this](int64_t pad_value, int64_t kernel_dim_size,
3566  llvm::StringRef pad_name,
3567  llvm::StringRef kernel_dim_name) -> LogicalResult {
3568  if (pad_value <= -kernel_dim_size)
3569  return emitOpError("expected ")
3570  << pad_name << " > -" << kernel_dim_name
3571  << ", but got: " << pad_name << "=" << pad_value << " and "
3572  << kernel_dim_name << "=" << kernel_dim_size;
3573  return success();
3574  };
3575 
3576  const llvm::ArrayRef<int64_t> padding = getOutPad();
3577  const int64_t outPadTop = padding[0];
3578  const int64_t outPadBottom = padding[1];
3579  const int64_t outPadLeft = padding[2];
3580  const int64_t outPadRight = padding[3];
3581 
3582  const auto weightType =
3583  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3584 
3585  if (weightType) {
3586  const int64_t kernelHeight = weightType.getDimSize(1);
3587  if (ShapedType::isStatic(kernelHeight)) {
3588  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3589  "out_pad_top", "KH")))
3590  return failure();
3591 
3592  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3593  "out_pad_bottom", "KH")))
3594  return failure();
3595  }
3596 
3597  const int64_t kernelWidth = weightType.getDimSize(2);
3598  if (ShapedType::isStatic(kernelWidth)) {
3599  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3600  "out_pad_left", "KW")))
3601  return failure();
3602 
3603  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3604  "out_pad_right", "KW")))
3605  return failure();
3606  }
3607  }
3608 
3609  // Rest of the checks depend on the output type being a RankedTensorType
3610  const auto outputType =
3611  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3612  if (!outputType)
3613  return success();
3614 
3615  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3616  if (inputType && weightType) {
3617  const int64_t inputHeight = inputType.getDimSize(1);
3618  const int64_t kernelHeight = weightType.getDimSize(1);
3619  const int64_t outputHeight = outputType.getDimSize(1);
3620 
3621  if (ShapedType::isStatic(inputHeight) &&
3622  ShapedType::isStatic(outputHeight)) {
3623  if (outputHeight !=
3624  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3625  return emitOpError(
3626  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3627  "+ out_pad_top + out_pad_bottom + KH, but got ")
3628  << outputHeight << " != (" << inputHeight << " - 1) * "
3629  << strideY << " + " << outPadTop << " + " << outPadBottom
3630  << " + " << kernelHeight;
3631  }
3632 
3633  const int64_t inputWidth = inputType.getDimSize(2);
3634  const int64_t kernelWidth = weightType.getDimSize(2);
3635  const int64_t outputWidth = outputType.getDimSize(2);
3636 
3637  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3638  if (outputWidth !=
3639  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3640  return emitOpError(
3641  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3642  "+ out_pad_left + out_pad_right + KW, but got ")
3643  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3644  << " + " << outPadLeft << " + " << outPadRight << " + "
3645  << kernelWidth;
3646  }
3647  }
3648 
3649  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3650 
3651  if (!biasType)
3652  return success();
3653 
3654  const int64_t biasChannels = biasType.getDimSize(0);
3655 
3656  // Skip further checks if bias is dynamic
3657  if (biasChannels == ShapedType::kDynamic)
3658  return success();
3659 
3660  const int64_t outputChannels = outputType.getDimSize(3);
3661  if (!ShapedType::isDynamic(outputChannels) &&
3662  biasChannels != outputChannels && biasChannels != 1)
3663  return emitOpError(
3664  "bias channels expected to be equal to output channels (")
3665  << outputChannels << ") or 1, got " << biasChannels;
3666 
3667  return success();
3668 }
3669 
3670 LogicalResult RescaleOp::verify() {
3671  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3672  if (!inputType) {
3673  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3674  return failure();
3675  }
3676 
3677  auto inputElementType =
3678  getStorageElementTypeOrSelf(inputType.getElementType());
3679  if (!mlir::isa<IntegerType>(inputElementType)) {
3680  emitOpError("expect input to have integer element type, got ")
3681  << inputElementType;
3682  return failure();
3683  }
3684 
3685  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3686  if (!outputType) {
3687  emitOpError("expect shaped tensor for output, got ")
3688  << getOutput().getType();
3689  return failure();
3690  }
3691 
3692  auto outputElementType =
3693  getStorageElementTypeOrSelf(outputType.getElementType());
3694  if (!mlir::isa<IntegerType>(outputElementType)) {
3695  emitOpError("expect output to have integer element type, got ")
3696  << outputElementType;
3697  return failure();
3698  }
3699 
3700  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3701  .failed())
3702  return failure();
3703 
3704  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3705  .failed())
3706  return failure();
3707 
3708  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3709  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3710  return failure();
3711 
3712  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3713  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3714  return failure();
3715 
3716  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3717  if (!multiplierType) {
3718  emitOpError("expect shaped tensor for multiplier, got ")
3719  << getMultiplier().getType();
3720  return failure();
3721  }
3722 
3723  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3724  if (!shiftType) {
3725  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3726  return failure();
3727  }
3728 
3729  // multiplier element type must be i32 for scale32 = true
3730  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3731  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3732  << multiplierType.getElementType();
3733  return failure();
3734  }
3735 
3736  // multiplier element type must be i16 for scale32 = false
3737  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3738  emitOpError(
3739  "expect i16 element type for multiplier for scale32=false, got ")
3740  << multiplierType.getElementType();
3741  return failure();
3742  }
3743 
3744  if (!inputType.hasRank())
3745  return success();
3746 
3747  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3748  // if per_channel = false; otherwise numChannels is the size of the input
3749  // shape's last dimension.
3750  int64_t numChannels = 1;
3751  if (getPerChannel()) {
3752  if (inputType.getRank() < 1) {
3753  emitOpError("requires input to be at least rank 1 when per_channel is "
3754  "true, but got rank ")
3755  << inputType.getRank();
3756  return failure();
3757  }
3758  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3759  }
3760 
3761  if (!multiplierType.hasRank())
3762  return success();
3763 
3764  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3765  // multiplier input has rank 1 by dialect definition
3766  if (multiplierShape[0] != ShapedType::kDynamic &&
3767  multiplierShape[0] != numChannels) {
3768  emitOpError("expect shape of { ")
3769  << numChannels << " } for multiplier input, got { "
3770  << multiplierShape[0] << " }";
3771  return failure();
3772  }
3773 
3774  if (!shiftType.hasRank())
3775  return success();
3776 
3777  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3778  // shift input has rank 1 by dialect definition
3779  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3780  emitOpError("expect shape of { ")
3781  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3782  return failure();
3783  }
3784 
3785  return success();
3786 }
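  // For example (illustrative shapes), with per_channel = true and an input of
  // shape [1, 10, 10, 4], the multiplier and shift operands must both have
  // shape {4}; with per_channel = false they must have shape {1}. With
  // scale32 = true the multiplier element type must be i32, otherwise i16.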
3787 
3788 LogicalResult RescaleOp::inferReturnTypeComponents(
3789  MLIRContext *context, ::std::optional<Location> location,
3790  RescaleOp::Adaptor adaptor,
3791  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3792  ShapeAdaptor inputShape(adaptor.getInput().getType());
3793  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3794  return success();
3795 }
3796 
3797 LogicalResult IfOp::inferReturnTypeComponents(
3798  MLIRContext *context, ::std::optional<Location> location,
3799  IfOp::Adaptor adaptor,
3800  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3801  llvm::SmallVector<tosa::YieldOp> yieldOps;
3802  for (Region *region : adaptor.getRegions()) {
3803  for (auto &block : *region)
3804  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3805  yieldOps.push_back(returnOp);
3806  }
3807 
3808  if (yieldOps.empty())
3809  return failure();
3810 
3811  // Get the initial type information for the yield op.
3812  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3813  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3814  for (auto operand : yieldOps.front().getOperands()) {
3815  resultKnowledge.push_back(
3816  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3817  }
3818 
3819  for (auto yieldOp : yieldOps) {
3820  if (resultKnowledge.size() != yieldOp.getNumOperands())
3821  return failure();
3822 
3823  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3824  int32_t index = it.index();
3825  auto meet = ValueKnowledge::meet(
3826  resultKnowledge[index],
3827  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3828  if (!meet)
3829  continue;
3830  resultKnowledge[index] = meet;
3831  }
3832  }
3833 
3834  for (const ValueKnowledge &result : resultKnowledge) {
3835  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3836  }
3837 
3838  return success();
3839 }
3840 
3841 LogicalResult WhileOp::inferReturnTypeComponents(
3842  MLIRContext *context, ::std::optional<Location> location,
3843  WhileOp::Adaptor adaptor,
3844  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3845  llvm::SmallVector<tosa::YieldOp> yieldOps;
3846  for (auto &block : adaptor.getBodyGraph())
3847  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3848  yieldOps.push_back(returnOp);
3849 
3850  // TOSA's while must have a tosa.yield as its terminator. If not found this
3851  // tosa.while is invalid.
3852  if (yieldOps.empty())
3853  return failure();
3854 
3855  // Get the initial type information from the operand types.
3856  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3857  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3858  for (auto operand : yieldOps.front().getOperands()) {
3859  resultKnowledge.push_back(
3860  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3861  }
3862 
3863  for (auto yieldOp : yieldOps) {
3864  if (resultKnowledge.size() != yieldOp.getNumOperands())
3865  return failure();
3866 
3867  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3868  int32_t index = it.index();
3869  if (auto meet = ValueKnowledge::meet(
3870  resultKnowledge[index],
3871  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3872  resultKnowledge[index] = meet;
3873  }
3874  }
3875  }
3876 
3877  for (const ValueKnowledge &result : resultKnowledge) {
3878  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3879  }
3880 
3881  return success();
3882 }
3883 
3884 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3885  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3886  return llvm::to_vector<4>(vt.getShape());
3887  return std::nullopt;
3888 }
3889 
3890 static void printInitializationList(OpAsmPrinter &parser,
3891  Block::BlockArgListType blocksArgs,
3892  ValueRange initializers,
3893  StringRef prefix = "") {
3894  assert(blocksArgs.size() == initializers.size() &&
3895  "expected same length of arguments and initializers");
3896  if (initializers.empty())
3897  return;
3898 
3899  parser << prefix << '(';
3900  llvm::interleaveComma(
3901  llvm::zip(blocksArgs, initializers), parser,
3902  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3903  parser << ")";
3904 }
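  // For example (illustrative SSA names), block arguments (%arg0, %arg1) with
  // initializers (%0, %1) and the single-space prefix used by IfOp/WhileOp
  // below print as " (%arg0 = %0, %arg1 = %1)"; nothing is printed when the
  // initializer list is empty.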
3905 
3906 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
3907 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3908  // Create the regions for 'then' and 'else'.
3909  result.regions.reserve(2);
3910  Region *thenRegion = result.addRegion();
3911  Region *elseRegion = result.addRegion();
3912 
3913  OpAsmParser::UnresolvedOperand cond;
3914 
3915  if (parser.parseOperand(cond))
3916  return failure();
3917 
3918  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3919  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3920 
3921  // Parse the optional block arguments
3922  OptionalParseResult listResult =
3923  parser.parseOptionalAssignmentList(regionArgs, operands);
3924  if (listResult.has_value() && failed(listResult.value()))
3925  return failure();
3926 
3927  // Parse a colon.
3928  if (failed(parser.parseColon()))
3929  return parser.emitError(parser.getCurrentLocation(),
3930  "expected type for condition operand");
3931 
3932  // Parse the type of the condition operand
3933  Type condType;
3934  if (failed(parser.parseType(condType)))
3935  return parser.emitError(parser.getCurrentLocation(),
3936  "expected type for condition operand");
3937 
3938  // Resolve operand with provided type
3939  if (failed(parser.resolveOperand(cond, condType, result.operands)))
3940  return failure();
3941 
3942  // Parse optional block arg types
3943  if (listResult.has_value()) {
3944  FunctionType functionType;
3945 
3946  if (failed(parser.parseType(functionType)))
3947  return parser.emitError(parser.getCurrentLocation())
3948  << "expected list of types for block arguments "
3949  << "followed by arrow type and list of return types";
3950 
3951  result.addTypes(functionType.getResults());
3952 
3953  if (functionType.getNumInputs() != operands.size()) {
3954  return parser.emitError(parser.getCurrentLocation())
3955  << "expected as many input types as operands "
3956  << "(expected " << operands.size() << " got "
3957  << functionType.getNumInputs() << ")";
3958  }
3959 
3960  // Resolve input operands.
3961  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3962  parser.getCurrentLocation(),
3963  result.operands)))
3964  return failure();
3965  } else {
3966  // Parse optional results type list.
3967  if (parser.parseOptionalArrowTypeList(result.types))
3968  return failure();
3969  }
3970 
3971  // Parse the 'then' region.
3972  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3973  return failure();
3974 
3975  // If we find an 'else' keyword then parse the 'else' region.
3976  if (!parser.parseOptionalKeyword("else")) {
3977  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3978  return failure();
3979  }
3980 
3981  // Parse the optional attribute list.
3982  if (parser.parseOptionalAttrDict(result.attributes))
3983  return failure();
3984  return success();
3985 }
3986 
3987 void IfOp::print(OpAsmPrinter &p) {
3988  p << " " << getCondition();
3989 
3990  printInitializationList(p, getThenGraph().front().getArguments(),
3991  getInputList(), " ");
3992  p << " : ";
3993  p << getCondition().getType();
3994 
3995  if (!getInputList().empty()) {
3996  p << " (";
3997  llvm::interleaveComma(getInputList().getTypes(), p);
3998  p << ")";
3999  }
4000  p.printArrowTypeList(getResultTypes());
4001  p << " ";
4002 
4003  p.printRegion(getThenGraph());
4004 
4005  // Print the 'else' region if it exists and has a block.
4006  auto &elseRegion = getElseGraph();
4007  if (!elseRegion.empty()) {
4008  p << " else ";
4009  p.printRegion(elseRegion);
4010  }
4011 
4012  p.printOptionalAttrDict((*this)->getAttrs());
4013 }
4014 
4015 LogicalResult IfOp::verify() {
4016  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4017  "'then_graph' arguments", getInputList(),
4018  "'input_list'")
4019  .failed())
4020  return failure();
4021 
4022  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4023  "'else_graph' arguments", getInputList(),
4024  "'input_list'")
4025  .failed())
4026  return failure();
4027 
4028  // If the terminator is missing, MLIR's region verification reports it for us.
4029  if (getThenGraph().front().mightHaveTerminator()) {
4030  auto thenYield =
4031  dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4032  if (thenYield && errorIfTypeOrShapeMismatch(
4033  *this, thenYield.getInputs(), "'then_graph' results",
4034  getOutputList(), "'output_list'")
4035  .failed())
4036  return failure();
4037  }
4038 
4039  // If the terminator is missing, MLIR's region verification reports it for us.
4040  if (getElseGraph().front().mightHaveTerminator()) {
4041  auto elseYield =
4042  dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4043  if (elseYield && errorIfTypeOrShapeMismatch(
4044  *this, elseYield.getInputs(), "'else_graph' results",
4045  getOutputList(), "'output_list'")
4046  .failed())
4047  return failure();
4048  }
4049 
4050  auto condType = getCondition().getType();
4051  if (errorIfShapeNotSizeOne(*this, condType).failed())
4052  return emitOpError() << "'condition' must be a size 1 tensor, got "
4053  << condType;
4054 
4055  return success();
4056 }
4057 
4058 LogicalResult WhileOp::verify() {
4059  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4060  getOutputList(), "'output_list'")
4061  .failed())
4062  return failure();
4063 
4064  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4065  "'cond_graph' arguments", getInputList(),
4066  "'input_list'")
4067  .failed())
4068  return failure();
4069 
4070  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4071  "'body_graph' arguments", getInputList(),
4072  "'input_list'")
4073  .failed())
4074  return failure();
4075 
4076  if (getBodyGraph().front().mightHaveTerminator()) {
4077  auto bodyYield =
4078  dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4079  if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4080  "'body_graph' results",
4081  getInputList(), "'input_list'")
4082  .failed())
4083  return failure();
4084  }
4085 
4086  // Condition block output must be a single element tensor with a single bool
4087  // value.
4088  if (!getCondGraph().front().mightHaveTerminator())
4089  return success();
4090 
4091  auto condYield =
4092  dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4093  if (!condYield)
4094  return success();
4095 
4096  if (condYield.getInputs().size() != 1)
4097  return emitOpError() << "require 'cond_graph' only have one result";
4098 
4099  auto condOutType = condYield.getInputs()[0].getType();
4100  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4101  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4102  << condOutType;
4103 
4104  if (!getElementTypeOrSelf(condOutType).isInteger(1))
4105  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4106  << condOutType;
4107 
4108  return success();
4109 }
4110 
4111 LogicalResult ReverseOp::verify() {
4112  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4113  /* outType = */ getOutput().getType())
4114  .failed())
4115  return failure();
4116  TensorType inputType = getInput1().getType();
4117  TensorType outputType = getOutput().getType();
4118  int32_t reverseAxis = getAxis();
4119 
4120  if (reverseAxis < 0)
4121  return emitOpError("expected non-negative reverse axis");
4122  if (inputType.hasRank()) {
4123  int64_t inputRank = inputType.getRank();
4124  // We allow for a special case where the input/output shape has rank 0 and
4125  // axis is also 0.
4126  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4127  return emitOpError("expect input tensor rank (")
4128  << inputRank << ") to be larger than reverse axis (" << reverseAxis
4129  << ")";
4130  }
4131  if (outputType.hasRank()) {
4132  int64_t outputRank = outputType.getRank();
4133  if (inputType.hasRank() && outputRank != inputType.getRank())
4134  return emitOpError(
4135  "expect output tensor rank to be equal to input tensor rank");
4136  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4137  return emitOpError("expect output tensor rank (")
4138  << outputRank << ") to be larger than reverse axis ("
4139  << reverseAxis << ")";
4140  }
4141  return success();
4142 }
4143 
4144 LogicalResult tosa::SelectOp::verify() {
4145  // verify input2 and input3 have same element type as output
4146  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4147  /* outType = */ getOutput().getType())
4148  .failed() ||
4149  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4150  /* outType = */ getOutput().getType())
4151  .failed()) {
4152  return failure();
4153  }
4154  // verify input1 has element type of bool
4155  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4156  if (!predicateType) {
4157  return emitOpError("expect shaped tensor for input1, got ")
4158  << getInput1().getType();
4159  }
4160  auto predicateElementType = predicateType.getElementType();
4161  if (!predicateElementType.isInteger(1)) {
4162  return emitOpError("expect element type of bool for input1, got ")
4163  << predicateElementType;
4164  }
4165 
4166  return success();
4167 }
4168 
4169 LogicalResult tosa::VariableOp::verify() {
4170  StringRef symName = getName();
4171  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
4172  if (succeeded(varOp))
4173  return emitOpError("illegal to have multiple declaration of '")
4174  << symName << "'";
4175 
4176  return success();
4177 }
4178 
4179 LogicalResult tosa::VariableReadOp::verify() {
4180  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4181  .failed())
4182  return failure();
4183 
4184  return success();
4185 }
4186 
4187 LogicalResult tosa::VariableWriteOp::verify() {
4188  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4189  .failed())
4190  return failure();
4191 
4192  return success();
4193 }
4194 
4195 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
4196 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4197  SmallVector<OpAsmParser::Argument, 4> regionArgs;
4198  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4199  Region *cond = result.addRegion();
4200  Region *body = result.addRegion();
4201 
4202  OptionalParseResult listResult =
4203  parser.parseOptionalAssignmentList(regionArgs, operands);
4204  if (listResult.has_value() && failed(listResult.value()))
4205  return failure();
4206 
4207  FunctionType functionType;
4208  SMLoc typeLoc = parser.getCurrentLocation();
4209  if (failed(parser.parseColonType(functionType)))
4210  return failure();
4211 
4212  result.addTypes(functionType.getResults());
4213 
4214  if (functionType.getNumInputs() != operands.size()) {
4215  return parser.emitError(typeLoc)
4216  << "expected as many input types as operands "
4217  << "(expected " << operands.size() << " got "
4218  << functionType.getNumInputs() << ")";
4219  }
4220 
4221  // Resolve input operands.
4222  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4223  parser.getCurrentLocation(),
4224  result.operands)))
4225  return failure();
4226 
4227  // Propagate the types into the region arguments.
4228  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4229  regionArgs[i].type = functionType.getInput(i);
4230 
4231  return failure(parser.parseRegion(*cond, regionArgs) ||
4232  parser.parseKeyword("do") || parser.parseRegion(*body) ||
4233  parser.parseOptionalAttrDictWithKeyword(result.attributes));
4234 }
4235 
4236 void WhileOp::print(OpAsmPrinter &parser) {
4237  printInitializationList(parser, getCondGraph().front().getArguments(),
4238  getInputList(), " ");
4239  parser << " : ";
4240  parser.printFunctionalType(getInputList().getTypes(),
4241  getResults().getTypes());
4242  parser << ' ';
4243  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4244  parser << " do ";
4245  parser.printRegion(getBodyGraph());
4246  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4247 }
4248 
4249 // Create a rank-1 const tensor for zero point of the source tensor.
4250 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4251  Location loc,
4252  Type srcElemType,
4253  int64_t zp) {
4254  srcElemType = getStorageElementTypeOrSelf(srcElemType);
4255  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4256  if (llvm::isa<FloatType>(srcElemType)) {
4257  auto zpAttr = DenseElementsAttr::get(
4258  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4259  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4260  }
4261  if (llvm::isa<IntegerType>(srcElemType)) {
4262  auto zpAttr =
4263  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4264  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4265  }
4266  llvm::errs() << "zero point is not allowed for unsupported data types\n";
4267  return std::nullopt;
4268 }
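  // Minimal usage sketch (illustrative only): materialize an i8 zero point of
  // value 0 at some location `loc` with an OpBuilder `builder`:
  //   std::optional<Value> zp =
  //       createZeroPointTensor(builder, loc, builder.getI8Type(), /*zp=*/0);
  //   // On success, *zp is a tosa.const of type tensor<1xi8>.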
4269 
4270 //===----------------------------------------------------------------------===//
4271 // TOSA Shape and Shape Operators Helper functions.
4272 //===----------------------------------------------------------------------===//
4273 
4275  return mlir::isa<tosa::shapeType>(t);
4276 }
4277 
4278 LogicalResult
4280  int rank) {
4281  if (rank < 0)
4282  return emitError() << "invalid rank (must be >= 0): " << rank;
4283  return success();
4284 }
4285 
4287  for (auto v : op->getOperands()) {
4288  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4289  Operation *definingOp = v.getDefiningOp();
4290  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4291  return op->emitOpError("shape operand is not compile time resolvable");
4292  }
4293  }
4294  }
4295  return success();
4296 }
4297 
4299  for (auto type : op->getOperandTypes()) {
4300  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4301  return op->emitOpError("must have operands with tosa shape type");
4302  }
4303  }
4304  for (auto type : op->getResultTypes()) {
4305  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4306  return op->emitOpError("must have result with tosa shape type");
4307  }
4308  }
4309  return success();
4310 }
4311 
4312 LogicalResult
4316  return failure();
4317 
4318  // Helper lambda returning the rank of a tosa shape type.
4319  auto getRank = [](const Type type) {
4320  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4321  };
4322  auto operandTypes = op->getOperandTypes();
4323  auto resultTypes = op->getResultTypes();
4324 
4325  auto rank = getRank(*op->getOperandTypes().begin());
4326  for (auto type : operandTypes) {
4327  if (getRank(type) != rank) {
4328  return op->emitOpError("operands don't have matching ranks");
4329  }
4330  }
4331  for (auto type : resultTypes) {
4332  if (getRank(type) != rank) {
4333  return op->emitOpError("result shape has different rank than operands");
4334  }
4335  }
4336  return success();
4337 }
4338 
4339 //===----------------------------------------------------------------------===//
4340 // TOSA Shape Operators verify functions.
4341 //===----------------------------------------------------------------------===//
4342 
4343 LogicalResult tosa::ConstShapeOp::verify() {
4344  // Check that the values attribute is one-dimensional (rank 1).
4345  auto valuesRank = getValues().getType().getRank();
4346  if (valuesRank != 1)
4347  return emitOpError("expect elements in attribute values with rank 1");
4348  // Check that the number of elements in the values attribute equals the rank
4349  // of the result shape type.
4349  auto count = getValues().getNumElements();
4350  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4351  if (count != rank && (count != 1 || rank != 0)) {
4352  return emitOpError("expect number of elements in attribute values (")
4353  << count << ") to be equal to the rank (" << rank
4354  << ") for the result shape type";
4355  }
4356  return success();
4357 }
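  // For example, a result of type !tosa.shape<3> requires a rank-1 values
  // attribute with exactly 3 elements; a rank-0 result shape accepts a
  // single-element attribute as the special case above.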
4358 
4359 //===----------------------------------------------------------------------===//
4360 // TOSA Attribute Definitions.
4361 //===----------------------------------------------------------------------===//
4362 
4363 #define GET_ATTRDEF_CLASSES
4364 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4365 
4366 //===----------------------------------------------------------------------===//
4367 // TOSA Type Definitions.
4368 //===----------------------------------------------------------------------===//
4369 #define GET_TYPEDEF_CLASSES
4370 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4371 
4372 //===----------------------------------------------------------------------===//
4373 // TOSA Operator Definitions.
4374 //===----------------------------------------------------------------------===//
4375 
4376 #define GET_OP_CLASSES
4377 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:253
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:192
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:207
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:919
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:4298
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4313
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:4286
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:225
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:250
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:4250
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:4274
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:552
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45