1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 #include <numeric>
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // Tosa dialect interface includes.
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47 
48 namespace {
49 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Dialect Function Inliner Interface.
53 //===----------------------------------------------------------------------===//
54 struct TosaInlinerInterface : public DialectInlinerInterface {
55  using DialectInlinerInterface::DialectInlinerInterface;
56 
57  //===--------------------------------------------------------------------===//
58  // Analysis Hooks.
59  //===--------------------------------------------------------------------===//
60 
61  /// All operations can be inlined by default.
62  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63  IRMapping &map) const final {
64  return true;
65  }
66 
67  /// All regions with If and While parent operators can be inlined.
68  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69  IRMapping &map) const final {
70  return (isa<tosa::IfOp>(dest->getParentOp()) ||
71  isa<tosa::WhileOp>(dest->getParentOp()));
72  }
73 };
74 
75 /// This class implements the bytecode interface for the Tosa dialect.
76 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77  TosaDialectBytecodeInterface(Dialect *dialect)
78  : BytecodeDialectInterface(dialect) {}
79 
80  //===--------------------------------------------------------------------===//
81  // Attributes
82 
83  Attribute readAttribute(DialectBytecodeReader &reader) const override {
84  return ::readAttribute(getContext(), reader);
85  }
86 
87  LogicalResult writeAttribute(Attribute attr,
88  DialectBytecodeWriter &writer) const override {
89  return ::writeAttribute(attr, writer);
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Types
94 
95  Type readType(DialectBytecodeReader &reader) const override {
96  return ::readType(getContext(), reader);
97  }
98 
99  LogicalResult writeType(Type type,
100  DialectBytecodeWriter &writer) const override {
101  return ::writeType(type, writer);
102  }
103 
104  void writeVersion(DialectBytecodeWriter &writer) const final {
105  // TODO: Populate.
106  }
107 
108  std::unique_ptr<DialectVersion>
109  readVersion(DialectBytecodeReader &reader) const final {
110  // TODO: Populate
111  reader.emitError("Dialect does not support versioning");
112  return nullptr;
113  }
114 
115  LogicalResult upgradeFromVersion(Operation *topLevelOp,
116  const DialectVersion &version) const final {
117  return success();
118  }
119 };
120 
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // TOSA control flow support.
125 //===----------------------------------------------------------------------===//
126 
127 /// Returns the while loop body.
128 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129  return {&getBodyGraph()};
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA variable operator support.
134 //===----------------------------------------------------------------------===//
135 
136 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137  return to_vector(llvm::map_range(shape, [](int64_t dim) {
138  return dim == -1 ? ShapedType::kDynamic : dim;
139  }));
140 }
141 
142 // Returns the ranked tensor type of a variable op.
143 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144  Type elementType = variableOp.getType();
145  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147  return RankedTensorType::get(shape, elementType);
148 }
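// Illustrative example (values assumed): a tosa.variable with
// var_shape = [2, -1] and element type f32 maps -1 to ShapedType::kDynamic
// above, so getVariableType returns tensor<2x?xf32>.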
149 
150 //===----------------------------------------------------------------------===//
151 // Tosa dialect initialization.
152 //===----------------------------------------------------------------------===//
153 
154 void TosaDialect::initialize() {
155  addTypes<
156 #define GET_TYPEDEF_LIST
157 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158  >();
159  addOperations<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162  >();
163  addAttributes<
164 #define GET_ATTRDEF_LIST
165 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166  >();
167  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168  declarePromisedInterfaces<
169  shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175  GreaterEqualOp, MatMulOp>();
176 }
177 
178 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179  Type type, Location loc) {
180  // Tosa dialect constants only support ElementsAttr, unlike the standard
181  // dialect constant op, which supports all attributes.
182  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183  return tosa::ConstShapeOp::create(builder, loc, type,
184  llvm::cast<DenseIntElementsAttr>(value));
185  }
186  if (llvm::isa<ElementsAttr>(value))
187  return tosa::ConstOp::create(builder, loc, type,
188  llvm::cast<ElementsAttr>(value));
189  return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Parsers and printers
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199  DenseElementsAttr &varShapeAttr,
200  TypeAttr &typeAttr) {
201  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202  if (!shapedType.hasRank())
203  return parser.emitError(parser.getCurrentLocation())
204  << "expected ranked type";
205 
206  auto elementType = shapedType.getElementType();
207  typeAttr = TypeAttr::get(elementType);
208  ArrayRef<int64_t> shape = shapedType.getShape();
209  Builder builder(parser.getContext());
210  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211  return success();
212  }
213  return parser.emitError(parser.getCurrentLocation())
214  << "expected shaped type";
215 }
216 
217 } // namespace
218 
219 // parses the optional initial value or type for a tosa variable
220 // with initial value:
221 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222 //
223 // without initial value:
224 // tosa.variable @name : tensor<1x8xf32>
225 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227  Attribute &initialValueAttr) {
228  if (succeeded(parser.parseOptionalEqual())) {
229  if (failed(parser.parseAttribute(initialValueAttr))) {
230  return parser.emitError(parser.getCurrentLocation())
231  << "expected attribute";
232  }
233  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235  typeAttr);
236  }
237  return parser.emitError(parser.getCurrentLocation())
238  << "expected Typed attr";
239  }
240 
241  initialValueAttr = nullptr;
242  Type parsedType;
243  if (failed(parser.parseColonType(parsedType))) {
244  return parser.emitError(parser.getCurrentLocation())
245  << "expected type after colon";
246  }
247  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248 }
249 
250 void mlir::tosa::printVariableOpTypeOrInitialValue(
251  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252  TypeAttr typeAttr, Attribute initialValueAttr) {
253  bool needsSpace = false;
254  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255  auto shape =
256  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257  Type elementType = typeAttr.getValue();
258  RankedTensorType tensorType =
259  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260  auto tensorTypeAttr = TypeAttr::get(tensorType);
261  p << ": ";
262  p.printAttribute(tensorTypeAttr);
263  needsSpace = true; // subsequent attr value needs a space separator
264  }
265  if (initialValueAttr) {
266  if (needsSpace)
267  p << ' ';
268  p << "= ";
269  p.printAttribute(initialValueAttr);
270  }
271 }
272 
273 namespace {
274 
275 // parse attributes with special handling for tosa enum attributes
276 template <typename EnumType>
277 ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278  NamedAttrList &outAttrs) {
279  llvm::StringRef name;
280  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281  return failure();
282 
283  // special handling: rounding_mode accepts a *bare* RoundingMode enum
284  // keyword.
285  llvm::StringRef kw;
286  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287  if (name == "rounding_mode" &&
288  succeeded(parser.parseOptionalKeyword(&kw))) {
289  auto sym = symbolizeRoundingMode(kw);
290  if (!sym)
291  return parser.emitError(parser.getCurrentLocation())
292  << "invalid rounding_mode value: " << kw;
293  auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294  outAttrs.push_back(NamedAttribute(name, attr));
295  return success();
296  }
297  }
298  // special handling: mode accepts a *bare* ResizeMode enum keyword.
299  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300  if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301  auto sym = symbolizeResizeMode(kw);
302  if (!sym)
303  return parser.emitError(parser.getCurrentLocation())
304  << "invalid resize mode value: " << kw;
305  auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306  outAttrs.push_back(NamedAttribute(name, attr));
307  return success();
308  }
309  }
310  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311  // keyword.
312  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313  if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314  auto sym = symbolizeNanPropagationMode(kw);
315  if (!sym)
316  return parser.emitError(parser.getCurrentLocation())
317  << "invalid nan_mode value: " << kw;
318  auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319  outAttrs.push_back(NamedAttribute(name, attr));
320  return success();
321  }
322  }
323 
324  // Default path: parse any normal attribute literal, including fully qualified
325  // enum keyword
326  Attribute attr;
327  return parser.parseAttribute(attr, name, outAttrs);
328 }
329 
330 template <typename EnumType>
331 ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
332  // parse operands
333  SmallVector<OpAsmParser::UnresolvedOperand> operands;
334  if (parser.parseCommaSeparatedList(
335  [&]() { return parser.parseOperand(operands.emplace_back()); }))
336  return failure();
337 
338  // Parse { attr-dict } with special handling for enum bare token
339  NamedAttrList attrs;
340  if (succeeded(parser.parseOptionalLBrace()) &&
341  failed(parser.parseOptionalRBrace())) {
342  do {
343  if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
344  return failure();
345  } while (succeeded(parser.parseOptionalComma()));
346  if (parser.parseRBrace())
347  return failure();
348  }
349 
350  FunctionType fnTy;
351  if (parser.parseColonType(fnTy))
352  return failure();
353 
354  // Resolve operands and types
355  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
356  parser.getCurrentLocation(),
357  result.operands)))
358  return failure();
359 
360  result.addTypes(fnTy.getResult(0));
361  result.addAttributes(attrs);
362 
363  return success();
364 }
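// For instance (assumed syntax): this lets ops such as tosa.rescale accept a
// bare enum keyword like {rounding_mode = SINGLE_ROUND} in the attribute
// dictionary, in addition to the fully qualified attribute form; all other
// attributes fall through to the ordinary parseAttribute path.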
365 
366 void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
367  parser << namedAttr.getName().strref() << " = ";
368  auto attr = namedAttr.getValue();
369  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
370  parser << roundingModeAttr.getValue();
371  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
372  parser << resizeModeAttr.getValue();
373  } else if (auto nanPropagationModeAttr =
374  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
375  parser << nanPropagationModeAttr.getValue();
376  } else {
377  parser.printAttribute(attr);
378  }
379 }
380 
381 // print with special handling for default valued NanPropagationMode attribute
382 void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
383  parser << " ";
384  parser.printOperands(op->getOperands());
385 
386  NamedAttrList toPrint(op->getAttrs());
387  // remove default NanPropagate attribute
388  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
389  for (auto attr : op->getAttrs()) {
390  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
391  if (nanAttr.getValue() == kDefaultNanValue) {
392  // elide from toPrint
393  toPrint.erase(attr.getName());
394  break;
395  }
396  }
397  }
398 
399  if (!toPrint.empty()) {
400  parser << " {";
401  llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
402  printNamedAttr(parser, namedAttr);
403  });
404  parser << "}";
405  }
406 
407  parser << " : ";
408  parser.printFunctionalType(op);
409 }
410 
411 // print with special handling for enums: RoundingMode, ResizeMode
412 void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
413  parser << " ";
414  parser.printOperands(op->getOperands());
415 
416  if (!op->getAttrs().empty()) {
417  parser << " {";
418  llvm::interleaveComma(op->getAttrs(), parser,
419  [&](const NamedAttribute namedAttr) {
420  printNamedAttr(parser, namedAttr);
421  });
422  parser << "}";
423  }
424 
425  parser << " : ";
426  parser.printFunctionalType(op);
427 }
428 
429 } // namespace
430 
431 ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
432  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
433 }
434 
435 void RescaleOp::print(OpAsmPrinter &parser) {
436  printWithEnumHandling(parser, *this);
437 }
438 
439 ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
440  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
441 }
442 
443 void ApplyScaleOp::print(OpAsmPrinter &parser) {
444  printWithEnumHandling(parser, *this);
445 }
446 
447 ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
448  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
449 }
450 
451 void ResizeOp::print(OpAsmPrinter &parser) {
452  printWithEnumHandling(parser, *this);
453 }
454 
455 ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
456  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
457 }
458 
459 void ArgMaxOp::print(OpAsmPrinter &parser) {
460  printWithNanPropagationHandling(parser, *this);
461 }
462 
463 ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
464  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
465 }
466 
467 void MaxPool2dOp::print(OpAsmPrinter &parser) {
468  printWithNanPropagationHandling(parser, *this);
469 }
470 
471 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
472  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
473 }
474 
475 void ClampOp::print(OpAsmPrinter &parser) {
476  printWithNanPropagationHandling(parser, *this);
477 }
478 
479 ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
480  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
481 }
482 
483 void MaximumOp::print(OpAsmPrinter &parser) {
484  printWithNanPropagationHandling(parser, *this);
485 }
486 
487 ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
488  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
489 }
490 
491 void MinimumOp::print(OpAsmPrinter &parser) {
492  printWithNanPropagationHandling(parser, *this);
493 }
494 
495 ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
496  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
497 }
498 
499 void ReduceMaxOp::print(OpAsmPrinter &parser) {
500  printWithNanPropagationHandling(parser, *this);
501 }
502 
503 ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
504  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
505 }
506 
507 void ReduceMinOp::print(OpAsmPrinter &parser) {
508  printWithNanPropagationHandling(parser, *this);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Tosa utilities.
513 //===----------------------------------------------------------------------===//
514 
515 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
516  if (lhs % rhs != 0)
517  return std::nullopt;
518  return lhs / rhs;
519 }
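// For example: idivCheck(12, 3) returns 4, while idivCheck(10, 3) returns
// std::nullopt because 10 is not evenly divisible by 3.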
520 
521 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
522  auto srcType = getElementTypeOrSelf(type);
523  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
524  srcType = quantType.getStorageType();
525  return srcType;
526 }
527 
528 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
529  return getStorageElementTypeOrSelf(value.getType());
530 }
531 
532 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
533  Value valZp, StringRef name) {
534  Type eType = getStorageElementTypeOrSelf(val.getType());
535  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
536 
537  bool bothInts =
538  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
539  bool sameBitWidth =
540  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
541 
542  if (!bothInts || !sameBitWidth) {
543  return op->emitOpError()
544  << "expected " << name << " and " << name
545  << "_zp to both be integer of the same bitwidth, but got " << eType
546  << " vs. " << eZpType;
547  }
548  return success();
549 }
550 
551 // Create a rank-1 pad_const tensor with the value `val` of the required data type.
552 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
553  Value src, int32_t val) {
554  const auto srcType = getElementTypeOrSelf(src);
555  const auto srcElemType = getStorageElementTypeOrSelf(src);
556  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
557  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
558  const auto padConstAttr{
559  llvm::isa<FloatType>(srcElemType)
560  ? DenseElementsAttr::get(padConstEType,
561  builder.getFloatAttr(srcElemType, val))
562  : DenseElementsAttr::get(padConstEType,
563  builder.getIntegerAttr(srcElemType, val))};
564  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
565 }
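// Illustrative example (values assumed): for an f32 source and val = 0 this
// creates a tosa.const with value dense<0.0> : tensor<1xf32>; for a quantized
// source the value attribute uses the storage type (e.g. tensor<1xi8>) while
// the result type keeps the quantized element type.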
566 
567 //===----------------------------------------------------------------------===//
568 // TOSA Operator Verifiers.
569 //===----------------------------------------------------------------------===//
570 
571 template <typename T>
572 static LogicalResult verifyConvOp(T op) {
573  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
574  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
575 
576  auto inputEType = inputType.getElementType();
577  auto weightEType = weightType.getElementType();
578  auto biasEType =
579  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
580  auto resultEType =
581  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
582  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
583  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
584 
585  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
586  inputEType = quantType.getStorageType();
587 
588  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
589  weightEType = quantType.getStorageType();
590 
591  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
592  biasEType = quantType.getStorageType();
593 
594  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
595  resultEType = quantType.getStorageType();
596 
597  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
598  // for now, only enforce bias element type == result element type for
599  // float types.
600  op.emitOpError(
601  "expect both bias and result to have same element type, got ")
602  << biasEType << " and " << resultEType;
603  return failure();
604  }
605 
606  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
607  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
608  if (inputEType != weightEType) {
609  op.emitOpError(
610  "expect both input and weight to have same element type, got ")
611  << inputEType << " and " << weightEType;
612  return failure();
613  }
614  }
615 
616  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
617  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
618 
619  // Either both must be float or both non-float.
620  if (inputIsFloat != weightIsFloat) {
621  op.emitOpError(
622  "expect both input and weight to be float or not together, got ")
623  << inputEType << " and " << weightEType;
624  return failure();
625  }
626 
627  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
628  if (inputEType != inputZpEType) {
629  return op.emitOpError("expect both input and its zero point are the same "
630  "element type, got ")
631  << inputEType << " and " << inputZpEType;
632  }
633 
634  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
635  if (weightEType != weightZpEType) {
636  return op.emitOpError("expect both weight and its zero point are the same "
637  "element type, got ")
638  << weightEType << " and " << weightZpEType;
639  }
640 
641  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
642  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
643  return failure();
644 
645  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
646  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
647  return failure();
648 
649  return success();
650 }
651 
652 LogicalResult tosa::ConstOp::verify() {
653 
654  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
655  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
656 
657  if (!attrType || !outputType) {
658  emitOpError("expected tensors for attr/result type");
659  return failure();
660  }
661 
662  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
663  outputType.getElementType())) {
664  if (result.getStorageType() == attrType.getElementType())
665  return success();
666  }
667 
668  if (attrType.getElementType() != outputType.getElementType()) {
669  emitOpError("expected same attr/result element types");
670  return failure();
671  }
672 
673  return success();
674 }
675 
676 template <typename T>
677 static LogicalResult verifyConvOpModes(T op) {
678  auto inputEType =
679  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
680 
681  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
682  inputEType = quantType.getStorageType();
683 
684  auto accType = op.getAccType();
685  if (inputEType.isInteger(8) && !accType.isInteger(32))
686  return op.emitOpError("accumulator type for i8 tensor is not i32");
687 
688  if (inputEType.isInteger(16) && !accType.isInteger(48))
689  return op.emitOpError("accumulator type for i16 tensor is not i48");
690 
691  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
692  return op.emitOpError("accumulator type for f8 tensor is not f16");
693 
694  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
695  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
696 
697  if (inputEType.isBF16() && !accType.isF32())
698  return op.emitOpError("accumulator type for bf16 tensor is not f32");
699 
700  if (inputEType.isF32() && !accType.isF32())
701  return op.emitOpError("accumulator type for f32 tensor is not f32");
702 
703  auto resultEType =
704  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
705 
706  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
707  resultEType = quantType.getStorageType();
708 
709  return success();
710 }
711 
712 //===----------------------------------------------------------------------===//
713 // ERROR_IF functions.
714 // ERROR_IF is a predicate that must set an error if the condition holds.
715 //===----------------------------------------------------------------------===//
716 
717 template <typename T>
718 static LogicalResult verifyConvOpErrorIf(T op) {
719  llvm::ArrayRef<int64_t> padding = op.getPad();
720  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
721  return op.emitOpError("expect all padding values to be >= 0, got ")
722  << padding;
723 
724  llvm::ArrayRef<int64_t> strides = op.getStride();
725  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
726  return op.emitOpError("expect all stride values to be >= 1, got ")
727  << strides;
728 
729  llvm::ArrayRef<int64_t> dilations = op.getDilation();
730  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
731  return op.emitOpError("expect all dilation values to be >= 1, got ")
732  << dilations;
733 
734  const RankedTensorType outputType =
735  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
736  if (!outputType)
737  // Skip following checks if output is not ranked
738  return success();
739 
740  const RankedTensorType inputType =
741  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
742  const RankedTensorType weightType =
743  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
744 
745  if (inputType && weightType) {
746  const auto verifyOutputSize =
747  [&op](const int64_t inputSize, const int64_t kernelSize,
748  const int64_t outputSize, const int64_t padBefore,
749  const int64_t padAfter, const int64_t stride,
750  const int64_t dilation, const llvm::StringRef dimName,
751  const llvm::StringRef dimAxis,
752  const llvm::StringRef padBeforeName,
753  const llvm::StringRef padAfterName) -> LogicalResult {
754  if (inputSize == ShapedType::kDynamic ||
755  kernelSize == ShapedType::kDynamic)
756  return success();
757 
758  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
759 
760  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
761  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
762  stride);
763  if (!calculatedOutSizeMinusOne.has_value())
764  return op.emitOpError("expected input_")
765  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
766  << padAfterName << " - (kernel_" << dimName
767  << " - 1) * dilation_" << dimAxis
768  << " to be wholly divisible by stride_" << dimAxis << ", got ("
769  << inputSize << " - 1 + " << padBefore << " + " << padAfter
770  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
771  << stride;
772 
773  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
774  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
775  return op.emitOpError("calculated output ")
776  << dimName << " did not match expected: "
777  << "calculated=" << calculatedOutSize
778  << ", expected=" << outputSize;
779 
780  return success();
781  };
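  // Worked example (values assumed): with input_height = 15, kernel_height = 3,
  // pad_top = pad_bottom = 1, stride_y = 2, and dilation_y = 1 the dividend is
  // 15 - 1 + 1 + 1 - (3 - 1) * 1 = 14, which divides evenly by 2, giving an
  // expected output height of 14 / 2 + 1 = 8; with stride_y = 3 the division is
  // inexact and an error is emitted.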
782 
783  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
784  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
785  if (failed(verifyOutputSize(
786  inputType.getDimSize(1), weightType.getDimSize(1),
787  outputType.getDimSize(1), padding[0], padding[1], strides[0],
788  dilations[0], "height", "y", "top", "bottom")))
789  return failure();
790 
791  if (failed(verifyOutputSize(
792  inputType.getDimSize(2), weightType.getDimSize(2),
793  outputType.getDimSize(2), padding[2], padding[3], strides[1],
794  dilations[1], "width", "x", "left", "right")))
795  return failure();
796  }
797 
798  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
799  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
800  if (failed(verifyOutputSize(
801  inputType.getDimSize(1), weightType.getDimSize(0),
802  outputType.getDimSize(1), padding[0], padding[1], strides[0],
803  dilations[0], "height", "y", "top", "bottom")))
804  return failure();
805 
806  if (failed(verifyOutputSize(
807  inputType.getDimSize(2), weightType.getDimSize(1),
808  outputType.getDimSize(2), padding[2], padding[3], strides[1],
809  dilations[1], "width", "x", "left", "right")))
810  return failure();
811  }
812 
813  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
814  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
815  if (failed(verifyOutputSize(
816  inputType.getDimSize(1), weightType.getDimSize(1),
817  outputType.getDimSize(1), padding[0], padding[1], strides[0],
818  dilations[0], "depth", "d", "front", "back")))
819  return failure();
820 
821  if (failed(verifyOutputSize(
822  inputType.getDimSize(2), weightType.getDimSize(2),
823  outputType.getDimSize(2), padding[2], padding[3], strides[1],
824  dilations[1], "height", "y", "top", "bottom")))
825  return failure();
826 
827  if (failed(verifyOutputSize(
828  inputType.getDimSize(3), weightType.getDimSize(3),
829  outputType.getDimSize(3), padding[4], padding[5], strides[2],
830  dilations[2], "width", "x", "left", "right")))
831  return failure();
832  }
833  }
834 
835  const RankedTensorType biasType =
836  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
837  if (!biasType)
838  // Skip following checks if bias is not ranked
839  return success();
840 
841  const int64_t biasChannels = biasType.getDimSize(0);
842  const int64_t outputChannels =
843  outputType.getDimSize(outputType.getRank() - 1);
844  if (biasChannels == ShapedType::kDynamic ||
845  outputChannels == ShapedType::kDynamic)
846  // Skip following checks if biasChannels or outputChannels is dynamic dim
847  return success();
848 
849  if (biasChannels != outputChannels && biasChannels != 1)
850  return op.emitOpError(
851  "bias channels expected to be equal to output channels (")
852  << outputChannels << ") or 1, got " << biasChannels;
853 
854  return success();
855 }
856 
857 // Verify that the two given types have the same element type and shape.
858 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
859  StringRef name1, Type type2,
860  StringRef name2) {
861  auto shapeType1 = dyn_cast<ShapedType>(type1);
862  auto shapeType2 = dyn_cast<ShapedType>(type2);
863  if (!shapeType1 || !shapeType2)
864  return failure();
865 
866  auto elemType1 = shapeType1.getElementType();
867  auto elemType2 = shapeType2.getElementType();
868  if (elemType1 != elemType2)
869  return op->emitOpError()
870  << "require same element type for " << name1 << " (" << elemType1
871  << ") and " << name2 << " (" << elemType2 << ")";
872 
873  if (failed(verifyCompatibleShape(type1, type2)))
874  return op->emitOpError()
875  << "require same shapes for " << name1 << " (" << type1 << ") and "
876  << name2 << " (" << type2 << ")";
877 
878  return success();
879 }
880 
881 // Verify that the two given tensor lists have the same length and that their elements have the same element type and shape.
882 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
883  StringRef name1,
884  ValueRange list2,
885  StringRef name2) {
886  if (list1.size() != list2.size())
887  return op->emitOpError()
888  << "require same number of values in " << name1 << " ("
889  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
890 
891  for (auto [type1, type2] :
892  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
893  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
894  return failure();
895  }
896 
897  return success();
898 }
899 
900 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
901  ShapeAdaptor shapeAdaptor(type);
902  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
903  return success();
904 
905  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
906 }
907 
908 // Returns the first declaration point prior to this operation or failure if
909 // not found.
910 static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
911  StringRef symName) {
912  ModuleOp module = op->getParentOfType<ModuleOp>();
913  tosa::VariableOp varOp = nullptr;
914 
915  // TODO: Adopt the SymbolTable trait for variable ops.
916  // Currently, the variable's declaration point is searched via walk(),
917  // starting from the top-level ModuleOp and stopping at this op's point of
918  // use. Once the TOSA control flow and variable extensions are complete, this
919  // search may leverage MLIR's symbol table functionality to look up the
920  // symbol, or a TOSA-specific graph traversal over the IR structure.
921  module.walk([&](Operation *tempOp) {
922  // Reach this op itself.
923  if (tempOp == op) {
924  return WalkResult::interrupt();
925  }
926 
927  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
928  if (symName == tosaOp.getName()) {
929  varOp = tosaOp;
930  return WalkResult::interrupt();
931  }
932  }
933 
934  return WalkResult::advance();
935  });
936 
937  if (varOp)
938  return varOp;
939 
940  return failure();
941 }
942 
943 template <typename T>
944 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
945  StringRef symName = op.getName();
946  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
947  if (failed(varOp))
948  return op->emitOpError("'")
949  << symName << "' has not been declared by 'tosa.variable'";
950 
951  // Verify type and shape
952  auto variableType = getVariableType(varOp.value());
953  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
954  "the input tensor")
955  .failed())
956  return failure();
957 
958  return success();
959 }
960 
961 // Verify that inType and outType have the same element type.
962 template <typename T>
963 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
964  auto inputType = llvm::dyn_cast<TensorType>(inType);
965  auto outputType = llvm::dyn_cast<TensorType>(outType);
966  if (!inputType) {
967  op.emitOpError("expect shaped tensor for input, got ") << inType;
968  return failure();
969  }
970  if (!outputType) {
971  op.emitOpError("expect shaped tensor for output, got ") << outType;
972  return failure();
973  }
974  auto inputElementType = inputType.getElementType();
975  auto outputElementType = outputType.getElementType();
976  auto inputQuantType =
977  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
978  auto outputQuantType =
979  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
980  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
981  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
982  inputElementType != outputElementType) {
983  // Only check when both element types are int/index/float/UniformQuantized;
984  // it is not clear how to compare other quant::QuantizedType cases. Such a
985  // case appears in test_conv2d_q_grouped_convolution in
986  // tfl-to-tosa-pipeline.mlir.
987  op.emitOpError("expect input and output to have same element type, got ")
988  << inputElementType << " and " << outputElementType;
989  return failure();
990  }
991  return success();
992 }
993 
994 LogicalResult tosa::ArgMaxOp::verify() {
995  const ShapedType resultType = llvm::cast<ShapedType>(getType());
996 
997  // Ensure output is of 32-bit integer
998  if (const auto resultETy = resultType.getElementType();
999  !resultETy.isIntOrIndex())
1000  return emitOpError("result tensor is not of integer type");
1001 
1002  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1003  if (!inputType.hasRank())
1004  return success();
1005 
1006  // Ensure axis is within the tensor rank
1007  const int64_t axis = getAxisAttr().getInt();
1008  if (((axis < 0) || axis >= inputType.getRank()))
1009  return emitOpError("specified axis is outside the rank of the tensor");
1010 
1011  if (!resultType.hasRank())
1012  return success();
1013 
1014  const ArrayRef<int64_t> inputShape = inputType.getShape();
1015  const ArrayRef<int64_t> outputShape = resultType.getShape();
1016  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1017  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1018  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1019  return emitOpError("expected output shape '")
1020  << expectedOutputShape << "', got '" << outputShape << "'";
1021 
1022  return success();
1023 }
1024 
1025 template <typename T>
1026 static LogicalResult verifyPoolingOp(T op) {
1027  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1028  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1029  return op.emitOpError("expect all kernel values to be >= 1, got ")
1030  << kernel;
1031 
1032  const llvm::ArrayRef<int64_t> strides = op.getStride();
1033  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1034  return op.emitOpError("expect all stride values to be >= 1, got ")
1035  << strides;
1036 
1037  const llvm::ArrayRef<int64_t> padding = op.getPad();
1038  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1039  return op.emitOpError("expect all padding values to be >= 0, got ")
1040  << padding;
1041 
1042  // Padding must be less than kernel size to avoid a divide-by-zero
1043  const int64_t kernelX = kernel[1];
1044  const int64_t padLeft = padding[2];
1045  const int64_t padRight = padding[3];
1046  if (padRight >= kernelX || padLeft >= kernelX)
1047  return op.emitOpError("expected left/right padding to be less than the "
1048  "width of the kernel, got pad_left=")
1049  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1050 
1051  const int64_t kernelY = kernel[0];
1052  const int64_t padTop = padding[0];
1053  const int64_t padBottom = padding[1];
1054  if (padTop >= kernelY || padBottom >= kernelY)
1055  return op.emitOpError("expected top/bottom padding to be less than the "
1056  "height of the kernel, got pad_top=")
1057  << padTop << ", pad_bottom=" << padBottom
1058  << ", kernel_y=" << kernelY;
1059 
1060  const auto inputType =
1061  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1062  const auto outputType =
1063  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1064  if (!inputType || !outputType)
1065  return success();
1066 
1067  const auto verifyOutputSize =
1068  [&op](const int64_t inputSize, const int64_t outputSize,
1069  const int64_t kernelSize, const int64_t strideSize,
1070  const int64_t padBefore, const int64_t padAfter,
1071  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1072  const llvm::StringRef padBeforeName,
1073  const llvm::StringRef padAfterName) -> LogicalResult {
1074  if (ShapedType::isDynamic(inputSize))
1075  return success();
1076 
1077  const std::optional<int64_t> calculatedOutSizeMinusOne =
1078  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1079  if (!calculatedOutSizeMinusOne.has_value())
1080  return op.emitOpError("expected input_")
1081  << dimName << " + pad_" << padBeforeName << " + pad_"
1082  << padAfterName << " - kernel_" << dimAxis
1083  << " to be wholly divisible by stride_" << dimAxis << ", got ("
1084  << inputSize << " + " << padBefore << " + " << padAfter << " - "
1085  << kernelSize << ") / " << strideSize;
1086 
1087  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1088  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1089  return op.emitOpError("calculated output ")
1090  << dimName << " did not match expected: "
1091  << "calculated=" << calculatedOutSize
1092  << ", expected=" << outputSize;
1093 
1094  return success();
1095  };
1096 
1097  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1098  kernel[0], strides[0], padding[0], padding[1],
1099  "height", "y", "top", "bottom")))
1100  return failure();
1101 
1102  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1103  kernel[1], strides[1], padding[2], padding[3],
1104  "width", "x", "left", "right")))
1105  return failure();
1106 
1107  return success();
1108 }
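// Worked example (values assumed): for input_height = 112, kernel_y = 3,
// stride_y = 2, pad_top = 0, and pad_bottom = 1 the dividend is
// 112 + 0 + 1 - 3 = 110, which divides evenly by 2, so the expected output
// height is 110 / 2 + 1 = 56.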
1109 
1110 LogicalResult tosa::AvgPool2dOp::verify() {
1111  if (failed(verifyPoolingOp(*this)))
1112  return failure();
1113 
1114  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1115  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1116  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1117  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1118 
1119  auto accType = getAccType();
1120  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1121  return emitOpError("accumulator type for integer tensor is not i32");
1122 
1123  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1124  return emitOpError("accumulator type for f16 tensor is not f16/f32");
1125 
1126  if (inputETy.isBF16() && !accType.isF32())
1127  return emitOpError("accumulator type for bf16 tensor is not f32");
1128 
1129  if (inputETy.isF32() && !accType.isF32())
1130  return emitOpError("accumulator type for f32 tensor is not f32");
1131 
1132  if (inputETy != inputZpETy)
1133  return emitOpError("expect both input and its zero point are the same "
1134  "element type, got ")
1135  << inputETy << " and " << inputZpETy;
1136 
1137  if (resultETy != outputZpETy)
1138  return emitOpError("expect both output and its zero point are the same "
1139  "element type, got ")
1140  << resultETy << " and " << outputZpETy;
1141 
1142  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1143  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1144  return failure();
1145 
1146  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1147  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1148  return failure();
1149 
1150  return success();
1151 }
1152 
1153 LogicalResult tosa::ClampOp::verify() {
1154  mlir::Type inputETy =
1155  llvm::cast<ShapedType>(getInput().getType()).getElementType();
1156  if (auto quantType =
1157  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1158  inputETy = quantType.getStorageType();
1159  }
1160  mlir::Type outputETy =
1161  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1162  if (auto quantType =
1163  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1164  outputETy = quantType.getStorageType();
1165  }
1166  if (inputETy != outputETy)
1167  return emitOpError("input/output element types are incompatible.");
1168 
1169  auto maxValAttr = getMaxValAttr();
1170  auto minValAttr = getMinValAttr();
1171 
1172  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1173 
1174  if (inputETy.isInteger(dataTypeBitWidth)) {
1175  // if input datatype is integer, check that the min_val/max_val attributes
1176  // are integer attributes, and that their type is the same as the input's
1177  // datatype
1178  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1179  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1180  if (!intMaxValAttr || !intMinValAttr ||
1181  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1182  (intMaxValAttr.getType() != inputETy))
1183  return emitOpError("min/max attributes types are incompatible with "
1184  "input/output element types.");
1185 
1186  const bool isUnsigned = inputETy.isUnsignedInteger();
1187  const bool isBoolean = inputETy.isInteger(1);
1188  const APInt minVal = intMinValAttr.getValue();
1189  const APInt maxVal = intMaxValAttr.getValue();
1190  if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1191  return emitOpError("expected min_val <= max_val, got min_val=")
1192  << minValAttr << ", max_val=" << maxValAttr;
1193  } else {
1194  // otherwise, input datatype is float, check that the min_val/max_val
1195  // attributes share the same type and that their type is the same as the
1196  // input's datatype
1197  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1198  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1199  if (!floatMaxValAttr || !floatMinValAttr ||
1200  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1201  (floatMaxValAttr.getType() != inputETy))
1202  return emitOpError("min/max attributes types are incompatible with "
1203  "input/output element types.");
1204 
1205  const APFloat minVal = floatMinValAttr.getValue();
1206  const APFloat maxVal = floatMaxValAttr.getValue();
1207  if (minVal.isNaN() || maxVal.isNaN())
1208  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1209  << minValAttr << ", max_val=" << maxValAttr;
1210 
1211  if (maxVal < minVal)
1212  return emitOpError("expected min_val <= max_val, got min_val=")
1213  << minValAttr << ", max_val=" << maxValAttr;
1214  }
1215 
1216  return success();
1217 }
1218 
1219 //===----------------------------------------------------------------------===//
1220 // TOSA Operator Quantization Builders.
1221 //===----------------------------------------------------------------------===//
1222 
1223 /// This builder is called on all convolution operators except TransposeConv,
1224 /// which has specialized output shape semantics. The builder also defines the
1225 /// bit width of the output given the bit widths of the input and weight.
1226 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1227  Type outputType, Value input, Value weight,
1228  Value bias, DenseI64ArrayAttr pad,
1229  DenseI64ArrayAttr stride,
1230  DenseI64ArrayAttr dilation,
1231  TypeAttr accType) {
1232  auto zps = createZPsAsConst(builder, input, weight);
1233  result.addOperands({input, weight, bias, zps.first, zps.second});
1234  result.addAttribute("pad", pad);
1235  result.addAttribute("stride", stride);
1236  result.addAttribute("dilation", dilation);
1237  result.addAttribute("acc_type", accType);
1238  Type finalOutputType = outputType;
1239  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1240  if (quantAttr) {
1241  finalOutputType =
1242  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1243  }
1244  result.addTypes(finalOutputType);
1245 }
1246 
1247 /// Handles tosa.transpose_conv2d which has outpad and output shape
1248 /// attributes.
1249 static void
1250 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1251  Type outputType, Value input, Value weight,
1252  Value bias, DenseI64ArrayAttr outpad,
1253  DenseI64ArrayAttr stride, TypeAttr accType) {
1254  auto zps = createZPsAsConst(builder, input, weight);
1255  result.addOperands({input, weight, bias, zps.first, zps.second});
1256  result.addAttribute("out_pad", outpad);
1257  result.addAttribute("stride", stride);
1258  result.addAttribute("acc_type", accType);
1259  Type finalOutputType = outputType;
1260  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1261  if (quantAttr) {
1262  finalOutputType =
1263  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1264  }
1265  result.addTypes(finalOutputType);
1266 }
1267 
1268 /// The tosa.matmul op is also intended to be generated where a fully_connected
1269 /// op must be constructed but the weight is not a constant. In this case,
1270 /// the fully_connected op must be expressed using matmul.
1271 /// TODO: Add a link to the legalization document explaining this.
1272 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1273  OperationState &result, Type outputType,
1274  Value a, Value b) {
1275  auto zps = createZPsAsConst(builder, a, b);
1276  result.addOperands({a, b, zps.first, zps.second});
1277 
1278  Type finalOutputType{outputType};
1279  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1280  auto eType = getStorageElementTypeOrSelf(a.getType());
1281  auto inputBits = eType.getIntOrFloatBitWidth();
1282 
1283  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1284  assert(outputShapedType && "Output must be a shaped type");
1285 
1286  IntegerType accElementType;
1287  if (inputBits == 16)
1288  accElementType = builder.getIntegerType(48);
1289  else
1290  accElementType = builder.getI32Type();
1291 
1292  finalOutputType = outputShapedType.clone(accElementType);
1293  }
1294  result.addTypes(finalOutputType);
1295 }
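// Illustrative example (values assumed): for quantized i8 operands with a
// declared output type of tensor<1x4x4xi8>, inputBits is 8, so the final
// result type is rewritten to the i32 accumulator type tensor<1x4x4xi32>;
// i16 operands would instead select an i48 accumulator.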
1296 
1297 /// Both the tosa.avg_pool2d and unary ops use the same
1298 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
1299 /// because it has additional parameters not part of the unary ops.
1300 static void
1301 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1302  Type outputType, Value input,
1303  DenseArrayAttr kernel, DenseArrayAttr stride,
1304  DenseArrayAttr pad, TypeAttr accType) {
1305  const Location loc{result.location};
1306  int64_t inputZp{0};
1307  int64_t outputZp{0};
1308 
1309  if (auto quantAttr =
1310  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1311  inputZp = quantAttr.getInputZp();
1312  outputZp = quantAttr.getOutputZp();
1313  }
1314  const std::optional<Value> inputZpOp =
1315  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1316  if (!inputZpOp) {
1317  (void)emitError(
1318  loc,
1319  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1320  }
1321  const std::optional<Value> outputZpOp =
1322  createZeroPointTensor(builder, loc, outputType, outputZp);
1323  if (!outputZpOp) {
1324  (void)emitError(loc, "Failed to create output zero point tensor for "
1325  "quantized AVG_POOL2D op");
1326  }
1327 
1328  if (inputZpOp && outputZpOp) {
1329  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1330  } else {
1331  // Failed to create one or more zero points above: just add the input as
1332  // operand; this will trigger an error when building the op because of the
1333  // missing zero points.
1334  result.addOperands({input});
1335  }
1336  result.addAttribute("kernel", kernel);
1337  result.addAttribute("stride", stride);
1338  result.addAttribute("pad", pad);
1339  result.addAttribute("acc_type", accType);
1340  result.types.push_back(outputType);
1341 }
1342 
1343 /// This builder is called on single-parameter negate operator
1344 /// to construct input and output zero points based on their
1345 /// types.
1346 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1347  OperationState &result, Type outputType,
1348  Value input) {
1349  const Location loc{result.location};
1350  int64_t input1Zp{0};
1351  int64_t outputZp{0};
1352  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1353  if (quantAttr) {
1354  input1Zp = quantAttr.getInputZp();
1355  outputZp = quantAttr.getOutputZp();
1356  }
1357  const std::optional<Value> input1ZpOp =
1358  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1359  if (!input1ZpOp) {
1360  (void)emitError(
1361  loc, "Failed to create input1 zero point for quantized NEGATE op");
1362  }
1363 
1364  const std::optional<Value> outputZpOp =
1365  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1366  if (!outputZpOp) {
1367  (void)emitError(
1368  loc, "Failed to create output zero point for quantized NEGATE op");
1369  }
1370 
1371  if (input1ZpOp && outputZpOp) {
1372  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1373  } else {
1374  // Failed to create one or more zero points above: just add the input as
1375  // operand. This will trigger an error when building the op because of the
1376  // missing zero points.
1377  result.addOperands({input});
1378  }
1379 
1380  result.types.push_back(outputType);
1381 }
1382 
1383 /// This builder is called on the TOSA pad operator, which needs to create its
1384 /// own OptionalAttr quantization_attr parameter to scale the padding values
1385 /// correctly. An absent pad_const is interpreted as zero padding.
1386 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1387  Type outputType, Value input,
1388  Value paddings) {
1389  const Location loc{result.location};
1390  int32_t zp{0};
1391  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1392  if (quantAttr) {
1393  zp = static_cast<int32_t>(quantAttr.getInputZp());
1394  }
1395  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1396  result.addOperands({input, paddings, padConstOp});
1397  result.types.push_back(outputType);
1398 }
1399 
1400 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1401  StringRef name, Type variableType,
1402  Attribute initialValue) {
1403  const Location loc{result.location};
1404  auto nameAttr = builder.getStringAttr(name);
1405 
1406  auto shapedType = dyn_cast<ShapedType>(variableType);
1407  if (!shapedType) {
1408  (void)emitError(loc, "variable type must be a shaped type");
1409  return;
1410  }
1411  if (!shapedType.hasRank()) {
1412  (void)emitError(loc, "variable type must be a ranked type");
1413  return;
1414  }
1415 
1416  auto elementType = shapedType.getElementType();
1417  auto elementTypeAttr = TypeAttr::get(elementType);
1418  ArrayRef<int64_t> shape = shapedType.getShape();
1419  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1420 
1421  result.addAttribute("name", nameAttr);
1422  result.addAttribute("var_shape", varShapeAttr);
1423  result.addAttribute("type", elementTypeAttr);
1424  result.addAttribute("initial_value", initialValue);
1425 }
1426 
1427 //===----------------------------------------------------------------------===//
1428 // TOSA Operator Return Type Inference.
1429 //===----------------------------------------------------------------------===//
1430 
1431 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1432  SmallVector<int64_t> &outShape) {
1433  int64_t outRank = 0;
1434  for (int i = 0, e = operands.size(); i != e; ++i) {
1435  auto shape = operands.getShape(i);
1436  if (!shape.hasRank()) {
1437  // TODO(jennik): Update function to have better case handling for
1438  // invalid operands and for ranked tensors.
1439  return failure();
1440  }
1441  outRank = std::max<int64_t>(outRank, shape.getRank());
1442  }
1443 
1444  outShape.resize(outRank, 1);
1445 
1446  for (int i = 0, e = operands.size(); i != e; ++i) {
1447  auto shape = operands.getShape(i);
1448  auto rankDiff = outShape.size() - shape.getRank();
1449 
1450  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1451  auto dim1 = outShape[i + rankDiff];
1452  auto dim2 = shape.getDimSize(i);
1453  auto resolvedDim = dim1;
1454 
1455  if (dim1 == 1) {
1456  resolvedDim = dim2;
1457  } else if (dim2 == 1) {
1458  resolvedDim = dim1;
1459  } else if (dim1 != dim2) {
1460  return failure();
1461  }
1462  outShape[i + rankDiff] = resolvedDim;
1463  }
1464  }
1465 
1466  return success();
1467 }
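// Illustrative example: operand shapes [2, 1, 3] and [4, 1] broadcast to the
// output shape [2, 4, 3]; shapes [2, 3] and [4, 3] fail because neither
// mismatched dimension is 1.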
1468 
1469 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1470  MLIRContext *context, ::std::optional<Location> location,
1471  ArgMaxOp::Adaptor adaptor,
1472  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1473  ShapeAdaptor inputShape(adaptor.getInput().getType());
1474  IntegerAttr axis = adaptor.getProperties().axis;
1475  int32_t axisVal = axis.getValue().getSExtValue();
1476 
1477  if (!inputShape.hasRank()) {
1478  inferredReturnShapes.push_back(ShapedTypeComponents());
1479  return success();
1480  }
1481 
1482  SmallVector<int64_t> outShape;
1483  outShape.reserve(inputShape.getRank() - 1);
1484  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1485  if (i == axisVal)
1486  continue;
1487  outShape.push_back(inputShape.getDimSize(i));
1488  }
1489 
1490  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1491  return success();
1492 }
1493 
1494 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1495  MLIRContext *context, ::std::optional<Location> location,
1496  RFFT2dOp::Adaptor adaptor,
1497  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1498  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1499 
1500  if (!inputShape.hasRank())
1501  return failure();
1502 
1503  llvm::SmallVector<int64_t> outputShape;
1504  outputShape.resize(3, ShapedType::kDynamic);
1505  outputShape[0] = inputShape.getDimSize(0);
1506  outputShape[1] = inputShape.getDimSize(1);
1507  int64_t inWidth = inputShape.getDimSize(2);
1508 
1509  // Note that we can support this calculation symbolically
1510  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
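      // For example, a static input of shape [8, 16, 32] infers two outputs of
      // shape [8, 16, 17] (32 / 2 + 1 = 17).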
1511  if (inWidth != ShapedType::kDynamic)
1512  outputShape[2] = inWidth / 2 + 1;
1513 
1514  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1515  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1516 
1517  return success();
1518 }
1519 
1520 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1521  const llvm::StringRef dimName) {
1522  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
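      // A power of two has a single set bit, so dimSize & (dimSize - 1) clears it
      // to zero, e.g. 16 & 15 == 0, while 12 & 11 == 8; dimSize > 0 rejects zero.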
1523  if (!isPowerOfTwo)
1524  return op->emitOpError("expected ")
1525  << dimName << " to be a power of two, got " << dimSize;
1526 
1527  return success();
1528 }
1529 
1530 LogicalResult tosa::RFFT2dOp::verify() {
1531  const auto outputTypes = getResultTypes();
1532  if (failed(verifyCompatibleShapes(outputTypes)))
1533  return emitOpError("expected output shapes to match, got ") << outputTypes;
1534 
1535  const auto inputType =
1536  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1537  if (!inputType)
1538  return success();
1539 
1540  const int64_t height = inputType.getDimSize(1);
1541  if (ShapedType::isStatic(height) &&
1542  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1543  return failure();
1544 
1545  const int64_t width = inputType.getDimSize(2);
1546  if (ShapedType::isStatic(width) &&
1547  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1548  return failure();
1549 
1550  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1551  if (!outputType)
1552  return success();
1553 
1554  // Batch and height input/output dimensions should match
1555  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1556  outputType.getShape().drop_back())))
1557  return emitOpError("expected batch and height dimensions of input/output "
1558  "to match, got input=")
1559  << inputType << " output=" << outputType;
1560 
1561  // Output width dimension expected to be input_width / 2 + 1
1562  const int64_t outputWidth = outputType.getDimSize(2);
1563  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1564  (outputWidth != (width / 2) + 1))
1565  return emitOpError(
1566  "expected output width to be equal to input_width / 2 + 1, got ")
1567  << outputWidth;
1568 
1569  return success();
1570 }
1571 
1572 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1573  MLIRContext *context, ::std::optional<Location> location,
1574  FFT2dOp::Adaptor adaptor,
1575  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1576  inferredReturnShapes.push_back(
1577  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1578  inferredReturnShapes.push_back(
1579  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1580  return success();
1581 }
1582 
1583 LogicalResult tosa::FFT2dOp::verify() {
1584  const auto inputRealType =
1585  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1586  const auto inputImagType =
1587  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1588  if (!inputRealType || !inputImagType)
1589  return success();
1590 
1591  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1592  return ShapedType::isDynamic(a) ? b : a;
1593  };
1594 
1595  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1596  inputImagType.getDimSize(1));
1597  if (ShapedType::isStatic(height) &&
1598  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1599  return failure();
1600 
1601  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1602  inputImagType.getDimSize(2));
1603  if (ShapedType::isStatic(width) &&
1604  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1605  return failure();
1606 
1607  return success();
1608 }
1609 
1610 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1611  MLIRContext *context, ::std::optional<Location> location,
1612  ConcatOp::Adaptor adaptor,
1613  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1614  // Infer all dimension sizes by reducing based on inputs.
1615  const Properties &prop = adaptor.getProperties();
1616  int32_t axis = prop.axis.getValue().getSExtValue();
1617  llvm::SmallVector<int64_t> outputShape;
1618  bool hasRankedInput = false;
1619  for (auto operand : adaptor.getOperands()) {
1620  ShapeAdaptor operandShape(operand.getType());
1621  if (!operandShape.hasRank())
1622  continue;
1623 
1624  // Copy the Operand's rank.
1625  if (!hasRankedInput)
1626  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1627 
1628  // Copy static dimension sizes and check that ranked inputs agree on non-axis dims.
1629  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1630  if (i == axis || operandShape.isDynamicDim(i))
1631  continue;
1632  if (outputShape[i] == ShapedType::kDynamic)
1633  outputShape[i] = operandShape.getDimSize(i);
1634  if (outputShape[i] != operandShape.getDimSize(i))
1635  return emitOptionalError(location,
1636  "Cannot concat tensors with different sizes"
1637  " on the non-axis dimension ",
1638  i);
1639  }
1640 
1641  hasRankedInput = true;
1642  }
1643 
1644  if (adaptor.getInput1().empty())
1645  return failure();
1646 
1647  Type inputType =
1648  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1649  if (!hasRankedInput) {
1650  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1651  return success();
1652  }
1653 
1654  // Determine the dimension size along the concatenation axis.
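      // For example, concatenating shapes [2, 3] and [2, 5] on axis 1 gives an
      // axis dimension of 8; any dynamic axis dimension makes the sum dynamic.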
1655  int64_t concatDimSize = 0;
1656  for (auto operand : adaptor.getOperands()) {
1657  ShapeAdaptor operandShape(operand.getType());
1658 
1659  // We need to know the length of the concatenation axis of all inputs to
1660  // determine the dimension size of the output shape.
1661  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1662  concatDimSize = ShapedType::kDynamic;
1663  break;
1664  }
1665 
1666  concatDimSize += operandShape.getDimSize(axis);
1667  }
1668 
1669  outputShape[axis] = concatDimSize;
1670 
1671  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1672  return success();
1673 }
1674 
1675 LogicalResult tosa::ConcatOp::verify() {
1676  // check that each input has same element type as output
1677  auto outType = getOutput().getType();
1678  const Operation::operand_range inputList = getInput1();
1679 
1680  // Check there is at least one input
1681  if (inputList.empty())
1682  return emitOpError("expect at least one input");
1683 
1684  if (!llvm::all_of(inputList, [&](auto input) {
1685  return succeeded(verifySameElementTypes(
1686  *this, /* inType = */ input.getType(), outType));
1687  })) {
1688  return failure();
1689  }
1690 
1691  const int32_t axis = getAxis();
1692  ShapeAdaptor firstRankedInputShape = nullptr;
1693  for (const auto &input : inputList) {
1694  const Type inputType = input.getType();
1695  ShapeAdaptor currShape(inputType);
1696  if (currShape.hasRank()) {
1697  firstRankedInputShape = currShape;
1698  // Check axis is in expected range
1699  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1700  return emitOpError("expect axis to be within range 0 <= axis < "
1701  "rank(input1[firstRankedTensorIdx]), got ")
1702  << axis;
1703  break;
1704  }
1705  }
1706 
1707  const auto allOperandsHasRank = [](const Value input) {
1708  return ShapeAdaptor(input.getType()).hasRank();
1709  };
1710  if (llvm::all_of(inputList, allOperandsHasRank)) {
1711  const int64_t firstInputRank = firstRankedInputShape.getRank();
1712 
1713  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1714  const ShapeAdaptor inputShape(input.getType());
1715  const int64_t inputRank = inputShape.getRank();
1716  const size_t operandNum = index + 1;
1717 
1718  // Check that each operand has the same rank
1719  if (inputRank != firstInputRank)
1720  return emitOpError(
1721  "expect all operands to have the same rank, but got ")
1722  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1723  << operandNum;
1724 
1725  // Check non-axis dims match
1726  for (int i = 0; i < inputRank; i++) {
1727  const int64_t inputDim = inputShape.getDimSize(i);
1728  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1729  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1730  inputShape.isDynamicDim(i))
1731  continue;
1732  if (inputDim != firstInputDim)
1733  return emitOpError("expect all operand shapes to have the same sizes "
1734  "on non-axis dimensions, but got ")
1735  << inputDim << " vs " << firstInputDim << " at index " << i
1736  << " on operands 0 and " << operandNum;
1737  }
1738  }
1739 
1740  // ERROR_IF(axis_sum != shape[axis]);
1741  int64_t axisSum = 0;
1742  for (const auto &input : inputList) {
1743  const ShapeAdaptor inputShape(input.getType());
1744  if (inputShape.isDynamicDim(axis)) {
1745  // make axisSum negative to indicate invalid value
1746  axisSum = -1;
1747  break;
1748  }
1749  axisSum += inputShape.getDimSize(axis);
1750  }
1751  const ShapeAdaptor outputShape(outType);
1752  if (axisSum >= 0 && outputShape.hasRank() &&
1753  !outputShape.isDynamicDim(axis) &&
1754  axisSum != outputShape.getDimSize(axis))
1755  return emitOpError("requires sum of axis dimensions of input1 "
1756  "equal to output axis dimension, got ")
1757  << axisSum << " and " << outputShape.getDimSize(axis);
1758  }
1759 
1760  return success();
1761 }
1762 
1763 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1764  MLIRContext *context, ::std::optional<Location> location,
1765  ValueShapeRange operands, DictionaryAttr attributes,
1766  OpaqueProperties properties, RegionRange regions,
1767  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1768  auto elementType = IntegerType::get(context, /*width=*/1);
1769 
1770  llvm::SmallVector<int64_t> outShape;
1771  if (resolveBroadcastShape(operands, outShape).failed()) {
1772  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1773  return success();
1774  }
1775 
1776  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1777  return success();
1778 }
1779 
1780 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1781  if (l.size() != r.size() || l.size() != 1)
1782  return false;
1783  return succeeded(verifyCompatibleShape(l[0], r[0]));
1784 }
1785 
1786 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1787  MLIRContext *context, ::std::optional<Location> location,
1788  MatMulOp::Adaptor adaptor,
1789  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1790  ShapeAdaptor lhsShape(adaptor.getA().getType());
1791  ShapeAdaptor rhsShape(adaptor.getB().getType());
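      // tosa.matmul operands are batched: a is [N, H, C] and b is [N, C, W],
      // producing an output of shape [N, H, W].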
1792 
1793  // Start with all output dimensions dynamic.
1794  SmallVector<int64_t> outShape;
1795  outShape.resize(3, ShapedType::kDynamic);
1796 
1797  if (lhsShape.hasRank()) {
1798  outShape[0] = lhsShape.getDimSize(0);
1799  outShape[1] = lhsShape.getDimSize(1);
1800  }
1801 
1802  if (rhsShape.hasRank()) {
1803  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1804  : outShape[0];
1805  outShape[2] = rhsShape.getDimSize(2);
1806  }
1807 
1808  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1809  return success();
1810 }
1811 
1812 LogicalResult MatMulOp::verify() {
1813  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1814  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1815 
1816  // Must be shaped tensor types
1817  if (!aType)
1818  return emitOpError("expect a shaped tensor for input a, got ")
1819  << getA().getType();
1820 
1821  if (!bType)
1822  return emitOpError("expect a shaped tensor for input b, got ")
1823  << getB().getType();
1824 
1825  auto aElementType = aType.getElementType();
1826  auto bElementType = bType.getElementType();
1827 
1828  auto aQuantizedEType =
1829  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1830  auto bQuantizedEType =
1831  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1832 
1833  if (aQuantizedEType || bQuantizedEType) {
1834  if (!aQuantizedEType || !bQuantizedEType) {
1835  return emitOpError("expect operands to be both quantized or both not "
1836  "quantized, got ")
1837  << aElementType << " and " << bElementType;
1838  }
1839  // both a and b have quantized element types
1840  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1841  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1842  if (aQuantWidth != bQuantWidth) {
1843  return emitOpError("expect quantized operands to have same widths, got ")
1844  << aQuantWidth << " and " << bQuantWidth;
1845  }
1846  } else {
1847  // non-quantized element types
1848  if (aElementType != bElementType) {
1849  return emitOpError("expect same element type for inputs a and b, got ")
1850  << aElementType << " and " << bElementType;
1851  }
1852  }
1853 
1854  // check a_zp and b_zp
1855  auto aEType = getStorageElementTypeOrSelf(aType);
1856  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1857  if (aEType != aZpEType) {
1858  return emitOpError("expect input a and a_zp have the same "
1859  "element type, got ")
1860  << aEType << " and " << aZpEType;
1861  }
1862 
1863  auto bEType = getStorageElementTypeOrSelf(bType);
1864  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1865  if (bEType != bZpEType) {
1866  return emitOpError("expect input b and b_zp have the same "
1867  "element type, got ")
1868  << bEType << " and " << bZpEType;
1869  }
1870 
1871  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1872  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1873  return failure();
1874 
1875  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1876  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1877  return failure();
1878 
1879  return success();
1880 }
1881 
1882 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1883  MLIRContext *context, ::std::optional<Location> location,
1884  PadOp::Adaptor adaptor,
1885  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1886  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1887  auto paddingRank =
1888  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1889  SmallVector<int64_t> outputShape;
1890 
1891  // If the input rank is unknown, we can infer the output rank using the
1892  // padding shape's rank divided by 2.
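      // For example, a padding shape of rank 6 implies a rank-3 output, since the
      // padding holds a [before, after] pair per input dimension.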
1893  if (!inputShape.hasRank()) {
1894  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1895  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1896  return success();
1897  }
1898 
1899  SmallVector<int64_t> paddingValues;
1900  // If the padding value is not a constant, all output dimensions are inferred as dynamic.
1901  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1902  paddingValues)) {
1903  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1904  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1905  return success();
1906  }
1907 
1908  outputShape.reserve(inputShape.getRank());
1909  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1910  if (inputShape.isDynamicDim(i)) {
1911  outputShape.push_back(ShapedType::kDynamic);
1912  continue;
1913  }
1914  auto padFront = paddingValues[i * 2];
1915  auto padBack = paddingValues[i * 2 + 1];
1916  if (padFront < 0 || padBack < 0) {
1917  // if either padding value for dim i is negative (dynamic), the output dim is unknown
1918  outputShape.push_back(ShapedType::kDynamic);
1919  continue;
1920  }
1921 
1922  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1923  }
1924 
1925  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1926  return success();
1927 }
1928 
1929 LogicalResult tosa::PadOp::verify() {
1930  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1931  /* outType = */ getOutput().getType())
1932  .failed()) {
1933  return failure();
1934  }
1935 
1936  if (auto padConst = getPadConst()) {
1937  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1938  /* outType = */ getOutput().getType())
1939  .failed()) {
1940  return failure();
1941  }
1942  }
1943 
1944  RankedTensorType inputType =
1945  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1946  RankedTensorType outputType =
1947  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1948  if (!inputType || !outputType)
1949  return success();
1950 
1951  auto inputRank = inputType.getRank();
1952  auto outputRank = outputType.getRank();
1953  if (inputRank != outputRank)
1954  return emitOpError() << "expect same input and output tensor rank, but got "
1955  << "inputRank: " << inputRank
1956  << ", outputRank: " << outputRank;
1957 
1958  DenseIntElementsAttr paddingAttr;
1959  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1960  return failure();
1961  }
1962 
1963  auto paddingValues = paddingAttr.getValues<APInt>();
1964  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1965  return emitOpError() << "padding tensor must have " << inputRank
1966  << " * 2 = " << inputRank * 2 << " elements, but got "
1967  << paddingValues.size();
1968 
1969  auto inputShape = inputType.getShape();
1970  auto outputShape = outputType.getShape();
1971 
1972  for (int64_t i = 0; i < inputRank; ++i) {
1973  int64_t padStart = paddingValues[i * 2].getSExtValue();
1974  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1975 
1976  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1977  return emitOpError()
1978  << "invalid padding values at dimension " << i
1979  << ": values must be non-negative or -1 for dynamic padding, got ["
1980  << padStart << ", " << padEnd << "]";
1981  }
1982 
1983  // Skip shape verification for dynamic input/output
1984  if (inputShape[i] == ShapedType::kDynamic ||
1985  outputShape[i] == ShapedType::kDynamic)
1986  continue;
1987 
1988  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1989  return emitOpError() << "mismatch in output shape at dimension " << i
1990  << ": expected " << inputShape[i] << " + "
1991  << padStart << " + " << padEnd << " = "
1992  << (inputShape[i] + padStart + padEnd)
1993  << ", but got " << outputShape[i];
1994  }
1995  }
1996 
1997  return success();
1998 }
1999 
2000 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2001  MLIRContext *context, ::std::optional<Location> location,
2002  SliceOp::Adaptor adaptor,
2003  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2004 
2005  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2006  SmallVector<int64_t> start;
2007  SmallVector<int64_t> size;
2008 
2009  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2010  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2011  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2012  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2013  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2014  return success();
2015  }
2016 
2017  // if size[i] is -1, all remaining elements in dimension i are included
2018  // in the slice, similar to TF.
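      // For example, an input of shape [10, 8] with start = [2, 0] and
      // size = [4, -1] infers an output shape of [4, 8].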
2019  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2020  // initialize outputShape to all unknown
2021  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2022  if (inputShape.hasRank()) {
2023  for (size_t i = 0; i < size.size(); i++) {
2024  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2025  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2026  start[i] < inputShape.getDimSize(i))) {
2027  // size[i] is not 0 and not < -1, and start[i] is in valid range
2028  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2029  // input shape has unknown dim[i] - only valid if size[i] > 0
2030  if (size[i] > 0) {
2031  outputShape[i] = size[i];
2032  }
2033  } else {
2034  // input shape has known dim[i]
2035  if (size[i] == -1) {
2036  outputShape[i] = inputShape.getDimSize(i) - start[i];
2037  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2038  // start[i] + size[i] is within bound of input shape's dim[i]
2039  outputShape[i] = size[i];
2040  }
2041  }
2042  }
2043  }
2044  } else {
2045  outputShape = convertToMlirShape(size);
2046  }
2047  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2048  return success();
2049 }
2050 
2051 LogicalResult tosa::SliceOp::verify() {
2052  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2053  /* outType = */ getOutput().getType())
2054  .failed())
2055  return failure();
2056 
2057  const ShapeAdaptor inputShape(getInput1().getType());
2058  if (inputShape.hasRank()) {
2059  const auto inputRank = inputShape.getRank();
2060  const ShapeAdaptor outputShape(getOutput().getType());
2061  if (outputShape.hasRank() && inputRank != outputShape.getRank())
2062  return emitOpError(
2063  "expect input1 and output to have the same ranks, got ")
2064  << inputRank << " and " << outputShape.getRank();
2065 
2066  const auto startShapeRank =
2067  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2068  if (inputRank != startShapeRank)
2069  return emitOpError("length of start is not equal to rank of input shape");
2070 
2071  const auto sizeShapeRank =
2072  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2073  if (inputRank != sizeShapeRank)
2074  return emitOpError("length of size is not equal to rank of input shape");
2075  }
2076 
2077  return success();
2078 }
2079 
2080 LogicalResult tosa::MulOp::inferReturnTypeComponents(
2081  MLIRContext *context, ::std::optional<Location> location,
2082  ValueShapeRange operands, DictionaryAttr attributes,
2083  OpaqueProperties properties, RegionRange regions,
2084  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2085  // The mul op's output shape depends only on input1 and input2, not on shift.
2086  ValueShapeRange twoInputs = operands.drop_back();
2087  llvm::SmallVector<int64_t> outShape;
2088  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2089  inferredReturnShapes.push_back(ShapedTypeComponents());
2090  } else {
2091  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2092  }
2093  return success();
2094 }
2095 
2096 LogicalResult tosa::MulOp::verify() {
2097  const Value output = getOutput();
2098  auto resElemType = getElementTypeOrSelf(output);
2099 
2100  // Verify if the element type among operands and result match tosa
2101  // specification.
2102  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2103  IntegerType lhsIntType =
2104  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2105  IntegerType rhsIntType =
2106  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2107  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2108  return emitOpError("requires the same element type for all operands");
2109 
2110  // Though the spec requires the element type of result to be i32, a more
2111  // relaxed way is provided at dialect level for easier cooperating with
2112  // other dialects.
2113  if (lhsIntType.getWidth() > resIntType.getWidth())
2114  return emitOpError("invalid data type size for operands or result");
2115 
2116  } else {
2117  // For the other supported types, the spec requires the same element
2118  // type for all operands (excluding the `shift` operand) and results.
2119  for (int i = 0; i < 2; ++i) {
2120  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2121  return emitOpError(
2122  "requires the same element type for all operands and results");
2123  }
2124 
2125  // verify shift has value 0 for non-integer types
2126  ElementsAttr shift_elem;
2127  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2128  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2129  if (shift != 0) {
2130  return emitOpError() << "require shift to be 0 for float type";
2131  }
2132  }
2133  }
2134 
2135  // Verify that all main operands and the result have the same rank, if known.
2136  // Extra operands such as mul's shift are excluded; this is the only
2137  // difference from the built-in `SameOperandsAndResultRank` trait.
2138  TypeRange operandTypes = getOperandTypes();
2139  ShapedType aType = cast<ShapedType>(operandTypes[0]);
2140  ShapedType bType = cast<ShapedType>(operandTypes[1]);
2141 
2142  const bool aHasRank = aType.hasRank();
2143  const bool bHasRank = bType.hasRank();
2144  if (aHasRank && bHasRank) {
2145  const int64_t aRank = aType.getRank();
2146  const int64_t bRank = bType.getRank();
2147  if (aRank != bRank)
2148  return emitOpError("a and b operands don't have matching ranks, got ")
2149  << aRank << " and " << bRank;
2150 
2151  // check for broadcast compatible shapes
2152  SmallVector<int64_t> resultShape;
2153  if (!mlir::OpTrait::util::getBroadcastedShape(
2154  aType.getShape(), bType.getShape(), resultShape))
2155  return emitOpError("a and b operands don't have broadcast-compatible "
2156  "shapes, got ")
2157  << aType << " and " << bType;
2158  }
2159 
2160  ShapedType resultType = cast<ShapedType>(output.getType());
2161  if (!resultType.hasRank())
2162  return success();
2163 
2164  const int64_t resultRank = resultType.getRank();
2165  if (aHasRank && resultRank != aType.getRank())
2166  return emitOpError("result type has different rank than a, got ")
2167  << resultRank << " vs " << aType.getRank();
2168  if (bHasRank && resultRank != bType.getRank())
2169  return emitOpError("result type has different rank than b, got ")
2170  << resultRank << " vs " << bType.getRank();
2171 
2172  return success();
2173 }
2174 
2175 LogicalResult tosa::TableOp::inferReturnTypeComponents(
2176  MLIRContext *context, ::std::optional<Location> location,
2177  TableOp::Adaptor adaptor,
2178  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2179  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2180 
2181  if (!inputShape.hasRank()) {
2182  inferredReturnShapes.push_back(ShapedTypeComponents());
2183  return success();
2184  }
2185 
2186  inferredReturnShapes.resize(1);
2187  inputShape.getDims(inferredReturnShapes[0]);
2188  return success();
2189 }
2190 
2191 LogicalResult tosa::TableOp::verify() {
2192  TensorType inputType = getInput1().getType();
2193  TensorType outputType = getOutput().getType();
2194 
2195  if (inputType.hasRank() && outputType.hasRank() &&
2196  inputType.getRank() != outputType.getRank())
2197  return emitOpError()
2198  << "expected input tensor rank to equal result tensor rank";
2199 
2200  auto inputDims = inputType.getShape();
2201  auto outputDims = outputType.getShape();
2202  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2203  int64_t dim = it.index();
2204  auto [inputDim, outputDim] = it.value();
2205  if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2206  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2207  << " doesn't match dim(input, " << dim
2208  << ") = " << inputDim;
2209  }
2210  }
2211  return success();
2212 }
2213 
2214 LogicalResult
2215 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2216  // Multiples must be constants.
2217  DenseIntElementsAttr multiplesAttr;
2218  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2219  return failure();
2220  multiples = llvm::to_vector(
2221  llvm::map_range(multiplesAttr.getValues<APInt>(),
2222  [](const APInt &val) { return val.getSExtValue(); }));
2223  return success();
2224 }
2225 
2226 LogicalResult tosa::TileOp::inferReturnTypeComponents(
2227  MLIRContext *context, ::std::optional<Location> location,
2228  TileOp::Adaptor adaptor,
2229  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2230  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2231  SmallVector<int64_t> multiples;
2232  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2233  multiples)) {
2234  auto rank =
2235  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2236  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2237  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2238  return success();
2239  } else {
2240  multiples = convertToMlirShape(multiples);
2241  }
2242 
2243  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2244  SmallVector<int64_t> outputShape;
2245  if (!inputShape.hasRank()) {
2246  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2247  inferredReturnShapes.push_back(
2248  ShapedTypeComponents(outputShape, inputType));
2249  return success();
2250  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2251  return failure();
2252 
2253  // Any non-dynamic dimension can be multiplied to a known size.
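      // For example, an input of shape [2, ?, 3] with multiples [3, 2, 2] gives
      // an output shape of [6, ?, 6].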
2254  outputShape.reserve(multiples.size());
2255  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2256  if (multiples[i] == ShapedType::kDynamic) {
2257  outputShape.push_back(ShapedType::kDynamic);
2258  } else {
2259  int64_t dim = inputShape.getDimSize(i);
2260  if (dim != ShapedType::kDynamic)
2261  dim *= multiples[i];
2262  outputShape.push_back(dim);
2263  }
2264  }
2265 
2266  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2267  return success();
2268 }
2269 
2270 LogicalResult tosa::TileOp::verify() {
2271  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2272  /* outType = */ getOutput().getType())
2273  .failed()) {
2274  return failure();
2275  }
2276  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2277  ShapedType outputType = llvm::cast<ShapedType>(getType());
2278 
2279  shapeType multiplesType =
2280  llvm::cast<tosa::shapeType>(getMultiples().getType());
2281 
2282  auto multiplesRank = multiplesType.getRank();
2283 
2284  if (inputType.hasRank()) {
2285  if (inputType.getRank() != multiplesRank)
2286  return emitOpError("expect 'multiples' to have rank ")
2287  << inputType.getRank() << " but got " << multiplesRank << ".";
2288  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2289  return emitOpError("expect same input and output tensor rank.");
2290  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2291  return emitOpError("expect 'multiples' array to have length ")
2292  << outputType.getRank() << " but got " << multiplesRank << ".";
2293 
2294  SmallVector<int64_t> multiples;
2295  if (getConstantMultiples(multiples).succeeded() &&
2296  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2297  return emitOpError(
2298  "expect element of 'multiples' to be positive integer or -1.");
2299 
2300  return success();
2301 }
2302 
2303 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2304  if (l.size() != r.size() || l.size() != 1)
2305  return false;
2306  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2307 }
2308 
2309 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2310  MLIRContext *context, ::std::optional<Location> location,
2311  ReshapeOp::Adaptor adaptor,
2312  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2313  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2314  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2315  llvm::SmallVector<int64_t> newShapeValue;
2316  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2317  newShapeValue)) {
2318  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2319  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2320  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2321  return success();
2322  } else {
2323  newShapeValue = convertToMlirShape(newShapeValue);
2324  }
2325 
2326  // We cannot infer from the total number of elements so we must take the
2327  // shape attribute as exact.
2328  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2329  inferredReturnShapes.push_back(
2330  ShapedTypeComponents(newShapeValue, inputType));
2331  return success();
2332  }
2333 
2334  // Compute the product of all static dimensions in the new shape. Together
2335  // with the input's total element count, this lets us infer the size of the
2336  // remaining dynamic dimension.
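      // For example, reshaping a 24-element input to [2, -1, 4] infers the
      // dynamic dimension as 24 / (2 * 4) = 3.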
2337  int64_t numElements = inputShape.getNumElements();
2338  int64_t staticMul = 1;
2339  for (auto val : newShapeValue) {
2340  if (ShapedType::isStatic(val)) {
2341  staticMul *= val;
2342  }
2343  }
2344 
2345  // Determine the length of the dynamic dimension.
2346  for (auto &val : newShapeValue) {
2347  if (ShapedType::isDynamic(val))
2348  val = numElements / staticMul;
2349  }
2350 
2351  inferredReturnShapes.push_back(
2352  ShapedTypeComponents(newShapeValue, inputType));
2353  return success();
2354 }
2355 
2356 llvm::LogicalResult tosa::ReshapeOp::verify() {
2357  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2358  /* outType = */ getOutput().getType())
2359  .failed()) {
2360  return failure();
2361  }
2362  TensorType inputType = getInput1().getType();
2363 
2364  SmallVector<int64_t> shapeValues;
2365  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2366  // skip following checks if shape is not constant
2367  return mlir::success();
2368  }
2369 
2370  int missingDims = llvm::count(shapeValues, -1);
2371  if (missingDims > 1)
2372  return emitOpError() << "expected at most one target dimension to be -1";
2373 
2374  const auto outputType = dyn_cast<RankedTensorType>(getType());
2375  if (!outputType)
2376  return success();
2377 
2378  if ((int64_t)shapeValues.size() != outputType.getRank())
2379  return emitOpError() << "new shape does not match result rank";
2380 
2381  for (auto [newShapeDim, outputShapeDim] :
2382  zip(shapeValues, outputType.getShape())) {
2383  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2384  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2385  return emitOpError() << "new shape is inconsistent with result shape";
2386 
2387  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2388  return emitOpError() << "new shape has invalid tensor dimension size "
2389  << newShapeDim;
2390  }
2391 
2392  if (inputType.hasStaticShape()) {
2393  int64_t inputElementsNum = inputType.getNumElements();
2394  if (outputType.hasStaticShape()) {
2395  int64_t outputElementsNum = outputType.getNumElements();
2396  if (inputElementsNum != outputElementsNum) {
2397  return emitOpError() << "cannot reshape " << inputElementsNum
2398  << " elements into " << outputElementsNum;
2399  }
2400  }
2401 
2402  int64_t newShapeElementsNum = std::accumulate(
2403  shapeValues.begin(), shapeValues.end(), 1LL,
2404  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2405  bool isStaticNewShape =
2406  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2407  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2408  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2409  return emitOpError() << "cannot reshape " << inputElementsNum
2410  << " elements into " << newShapeElementsNum;
2411  }
2412  }
2413 
2414  return mlir::success();
2415 }
2416 
2417 // return failure if val is not a constant
2418  // set zp to -1 if val is a non-zero float or is neither integer nor float
2419 // otherwise set zp to val's constant value
2420 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2421  ElementsAttr zpAttr;
2422  if (!matchPattern(val, m_Constant(&zpAttr))) {
2423  return failure();
2424  }
2425 
2426  Type zpElemType = zpAttr.getElementType();
2427 
2428  if (llvm::isa<FloatType>(zpElemType)) {
2429  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2430  return 0;
2431  }
2432  // return non-zero value to trigger error check
2433  return -1;
2434  }
2435 
2436  if (llvm::isa<IntegerType>(zpElemType)) {
2437  if (signExtend)
2438  return zpAttr.getValues<APInt>()[0].getSExtValue();
2439  else
2440  return zpAttr.getValues<APInt>()[0].getZExtValue();
2441  }
2442 
2443  // return non-zero value to trigger error check
2444  return -1;
2445 }
2446 
2447 template <typename T>
2448 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2449  const std::string &operand) {
2450  Type zpElemType = getElementTypeOrSelf(val);
2451 
2452  if (!zpElemType.isInteger(8) && zp != 0) {
2453  // convert operand to lower case for error message
2454  std::string lower = operand;
2455  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2456  return op.emitOpError()
2457  << lower << " zero point must be zero for non-int8 integer types";
2458  }
2459 
2460  return success();
2461 }
2462 
2463 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2464  const int64_t &zp,
2465  const std::string &operand) {
2466  bool isInputZp = (operand == "Input");
2467 
2468  bool tensorUnsigned =
2469  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2470  StringRef tensorName = isInputZp ? "input" : "output";
2471 
2472  Type zpElemType = getElementTypeOrSelf(zpVal);
2473 
2474  if (zp != 0) {
2475  if (!zpElemType.isInteger(8) &&
2476  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2477  return op.emitOpError()
2478  << "expect " << tensorName << "_zp of 0, got " << zp;
2479  }
2480  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2481  return op.emitOpError() << "expect " << tensorName
2482  << "_zp of 0 or 32768 for unsigned int16 "
2483  << tensorName << ", got " << zp;
2484  }
2485  }
2486 
2487  return success();
2488 }
2489 
2490 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2491  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2492  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2493  } \
2494  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2495  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2496  }
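      // For example, ZERO_POINT_HELPER(MatMulOp, A, true) defines
      // tosa::MatMulOp::getAZeroPoint() and tosa::MatMulOp::verifyAZeroPoint(int64_t),
      // reading the zero point from getAZp() with sign extension.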
2497 
2498 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2499 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2500 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2501 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2502 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2503 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2504 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2505 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2506 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2507 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2508 ZERO_POINT_HELPER(MatMulOp, A, true)
2509 ZERO_POINT_HELPER(MatMulOp, B, true)
2510 ZERO_POINT_HELPER(NegateOp, Input1, true)
2511 ZERO_POINT_HELPER(NegateOp, Output, true)
2512 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2513 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2514 #undef ZERO_POINT_HELPER
2515 
2516 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2517  MLIRContext *context, ::std::optional<Location> location,
2518  TransposeOp::Adaptor adaptor,
2519  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2520  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2521 
2522  // If the input rank is unknown, the output rank is also
2523  // unknown.
2524  if (!inputShape.hasRank()) {
2525  inferredReturnShapes.push_back(ShapedTypeComponents());
2526  return success();
2527  }
2528 
2529  const auto inputRank = inputShape.getRank();
2530 
2531  // This would imply the number of permutations does not match the rank of
2532  // the input which is illegal.
2533  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2534  return failure();
2535  }
2536 
2537  SmallVector<int64_t> outputShape;
2538  // Rank-0 means no permutations matter.
2539  if (inputRank == 0) {
2540  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2541  return success();
2542  }
2543 
2544  // Check whether the input dimensions are all the same.
2545  bool allTheSame = true;
2546  for (int i = 1, s = inputRank; i < s; i++) {
2547  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2548  allTheSame = false;
2549  break;
2550  }
2551  }
2552 
2553  // If all of the input dimensions are the same we don't care about the
2554  // permutation.
2555  if (allTheSame) {
2556  outputShape.resize(inputRank, inputShape.getDimSize(0));
2557  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2558  return success();
2559  }
2560 
2561  outputShape.resize(inputRank, ShapedType::kDynamic);
2562 
2563  // Constant permutation values must be within the input rank.
2564  if (llvm::any_of(adaptor.getPerms(),
2565  [inputRank](const auto i) { return i >= inputRank; }))
2566  return failure();
2567 
2568  outputShape.reserve(inputRank);
2569  for (int i = 0, s = inputRank; i < s; i++) {
2570  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2571  }
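      // For example, an input of shape [1, 2, 3] with perms = [2, 0, 1] infers an
      // output shape of [3, 1, 2].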
2572 
2573  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2574  return success();
2575 }
2576 
2577 LogicalResult tosa::TransposeOp::verify() {
2578  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2579  /* outType = */ getOutput().getType())
2580  .failed()) {
2581  return failure();
2582  }
2583 
2584  const ShapeAdaptor inputShape(getInput1().getType());
2585  const ShapeAdaptor outputShape(getOutput().getType());
2586 
2587  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2588 
2589  if (inputShape.hasRank() &&
2590  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2591  return emitOpError() << "expected perms attribute to have size "
2592  << inputShape.getRank()
2593  << " (input rank) but got size "
2594  << constantPerms.size();
2595 
2596  if (inputShape.hasRank() && outputShape.hasRank() &&
2597  inputShape.getRank() != outputShape.getRank())
2598  return emitOpError()
2599  << "expected input tensor rank to equal result tensor rank";
2600 
2601  if (outputShape.hasRank() &&
2602  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2603  return emitOpError() << "expected perms attribute to have size "
2604  << outputShape.getRank()
2605  << " (output rank) but got size "
2606  << constantPerms.size();
2607 
2608  if (!llvm::all_of(constantPerms,
2609  [&constantPerms](int32_t s) {
2610  return s >= 0 &&
2611  static_cast<size_t>(s) < constantPerms.size();
2612  }) ||
2613  !isPermutationVector(llvm::to_vector(llvm::map_range(
2614  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2615  return emitOpError() << "expected valid permutation indices";
2616 
2617  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2618  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2619  inputShape.getNumElements() != outputShape.getNumElements())
2620  return emitOpError() << "expected input1 and output to have same numbers "
2621  "of elements, got "
2622  << inputShape.getNumElements() << " and "
2623  << outputShape.getNumElements();
2624 
2625  // Verify that the shapes of the input and output tensors are properly
2626  // permuted.
2627  if (inputShape.hasRank() && outputShape.hasRank()) {
2628  for (auto i = 0; i < outputShape.getRank(); i++) {
2629  if (inputShape.isDynamicDim(constantPerms[i]) ||
2630  outputShape.isDynamicDim(i))
2631  continue;
2632 
2633  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2634  return emitOpError()
2635  << "expected output tensor dim " << i << " to match "
2636  << "input dim " << constantPerms[i] << " with value of "
2637  << inputShape.getDimSize(constantPerms[i]);
2638  }
2639  }
2640 
2641  return success();
2642 }
2643 
2644 LogicalResult TransposeOp::reifyResultShapes(
2645  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2646 
2647  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2648 
2649  Value input = getInput1();
2650  auto inputType = cast<TensorType>(input.getType());
2651 
2652  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2653  for (auto dim : transposePerms) {
2654  int32_t dimInInput = transposePerms[dim];
2655  if (inputType.isDynamicDim(dimInInput))
2656  returnedDims[dim] =
2657  tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2658  .getResult();
2659  else
2660  returnedDims[dim] =
2661  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2662  }
2663 
2664  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2665  return success();
2666 }
2667 
2668 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2669  MLIRContext *context, ::std::optional<Location> location,
2670  GatherOp::Adaptor adaptor,
2671  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2672  llvm::SmallVector<int64_t> outputShape;
2673  outputShape.resize(3, ShapedType::kDynamic);
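      // tosa.gather shapes: values is [N, K, C], indices is [N, W], and the
      // output is [N, W, C].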
2674 
2675  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2676  if (valuesShape.hasRank()) {
2677  outputShape[0] = valuesShape.getDimSize(0);
2678  outputShape[2] = valuesShape.getDimSize(2);
2679  }
2680 
2681  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2682  if (indicesShape.hasRank()) {
2683  if (outputShape[0] == ShapedType::kDynamic)
2684  outputShape[0] = indicesShape.getDimSize(0);
2685  if (outputShape[1] == ShapedType::kDynamic)
2686  outputShape[1] = indicesShape.getDimSize(1);
2687  }
2688 
2689  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2690  return success();
2691 }
2692 
2693 LogicalResult tosa::GatherOp::verify() {
2694  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2695  /* outType = */ getOutput().getType())
2696  .failed()) {
2697  return failure();
2698  }
2699 
2700  const ShapeAdaptor valuesShape(getValues().getType());
2701  const ShapeAdaptor indicesShape(getIndices().getType());
2702  const ShapeAdaptor outputShape(getOutput().getType());
2703 
2704  int64_t N = ShapedType::kDynamic;
2705  int64_t W = ShapedType::kDynamic;
2706  int64_t C = ShapedType::kDynamic;
2707 
2708  if (valuesShape.hasRank()) {
2709  N = valuesShape.getDimSize(0);
2710  C = valuesShape.getDimSize(2);
2711  }
2712  if (indicesShape.hasRank()) {
2713  const int64_t indicesN = indicesShape.getDimSize(0);
2714  W = indicesShape.getDimSize(1);
2715  if (N == ShapedType::kDynamic)
2716  N = indicesN;
2717  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2718  return emitOpError() << "requires indices dimension 0 to have size " << N
2719  << ", got " << indicesN;
2720  }
2721  if (outputShape.hasRank()) {
2722  const int64_t outputN = outputShape.getDimSize(0);
2723  const int64_t outputW = outputShape.getDimSize(1);
2724  const int64_t outputC = outputShape.getDimSize(2);
2725  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2726  N != outputN)
2727  return emitOpError() << "requires output dimension 0 to have size " << N
2728  << ", got " << outputN;
2729 
2730  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2731  W != outputW)
2732  return emitOpError() << "requires output dimension 1 to have size " << W
2733  << ", got " << outputW;
2734  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2735  C != outputC)
2736  return emitOpError() << "requires output dimension 2 to have size " << C
2737  << ", got " << outputC;
2738  }
2739  return success();
2740 }
2741 
2742 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2743  MLIRContext *context, ::std::optional<Location> location,
2744  ResizeOp::Adaptor adaptor,
2745  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2746  llvm::SmallVector<int64_t, 4> outputShape;
2747  outputShape.resize(4, ShapedType::kDynamic);
2748 
2749  ShapeAdaptor inputShape(adaptor.getInput().getType());
2750  if (!inputShape.hasRank())
2751  return failure();
2752 
2753  outputShape[0] = inputShape.getDimSize(0);
2754  outputShape[3] = inputShape.getDimSize(3);
2755  int64_t inputHeight = inputShape.getDimSize(1);
2756  int64_t inputWidth = inputShape.getDimSize(2);
2757 
2758  if ((inputHeight == ShapedType::kDynamic) ||
2759  (inputWidth == ShapedType::kDynamic))
2760  return failure();
2761 
2762  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2763  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2764  scaleInt) ||
2765  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2766  offsetInt) ||
2767  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2768  borderInt)) {
2769  return failure();
2770  }
2771 
2772  // Compute the output shape from the scale, offset, and border values.
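      // For example, ih = 16 with scale = [4, 2, 4, 2], offset = [0, 0] and
      // border = [0, 0] gives oh = ((16 - 1) * 4 - 0 + 0) / 2 + 1 = 31.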
2773  const int64_t outputHeight =
2774  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2775  scaleInt[1]) +
2776  1;
2777 
2778  const int64_t outputWidth =
2779  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2780  scaleInt[3]) +
2781  1;
2782 
2783  if (outputHeight < 0 || outputWidth < 0) {
2784  return emitOptionalError(
2785  location,
2786  "calculated output height and width must be non-negative, "
2787  "got height = ",
2788  outputHeight, ", width = ", outputWidth);
2789  }
2790 
2791  outputShape[1] = outputHeight;
2792  outputShape[2] = outputWidth;
2793  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2794  return success();
2795 }
2796 
2797 LogicalResult tosa::ResizeOp::verify() {
2798  const Value input = getInput();
2799  const Value output = getOutput();
2800  const RankedTensorType inputType =
2801  llvm::dyn_cast<RankedTensorType>(input.getType());
2802  const RankedTensorType outputType =
2803  llvm::dyn_cast<RankedTensorType>(output.getType());
2804 
2805  SmallVector<int64_t> scaleValues;
2806  SmallVector<int64_t> offsetValues;
2807  SmallVector<int64_t> borderValues;
2808  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2809  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2810  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2811  // Skip following checks if shape is not constant
2812  return success();
2813  }
2814 
2815  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2816  return emitOpError("expect all scale values to be > 0, got ")
2817  << scaleValues;
2818 
2819  const int64_t scaleYN = scaleValues[0];
2820  const int64_t scaleYD = scaleValues[1];
2821  const int64_t scaleXN = scaleValues[2];
2822  const int64_t scaleXD = scaleValues[3];
2823 
2824  const int64_t offsetY = offsetValues[0];
2825  const int64_t offsetX = offsetValues[1];
2826 
2827  const int64_t borderY = borderValues[0];
2828  const int64_t borderX = borderValues[1];
2829 
2830  if (!inputType)
2831  return success();
2832  if (!outputType)
2833  return success();
2834 
2835  const int64_t oh = outputType.getDimSize(1);
2836  const int64_t ow = outputType.getDimSize(2);
2837  const int64_t ih = inputType.getDimSize(1);
2838  const int64_t iw = inputType.getDimSize(2);
2839 
2840  // Skip the check when the input height may be broadcast (ih == 1),
2841  // since Linalg, a consumer of TOSA, expects broadcasting support
2842  // in resize to be available. Taking the cautious approach for now,
2843  // we can consider removing support for broadcasting later.
2844  if (ih != ShapedType::kDynamic && ih != 1) {
2845  const std::optional<int64_t> calculatedOutHeightMinusOne =
2846  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2847  if (!calculatedOutHeightMinusOne.has_value())
2848  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2849  "border_y ")
2850  << "to be wholly divisible by scale_y_d, got ((" << ih
2851  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2852  << ") / " << scaleYD;
2853  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2854  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2855  return emitOpError("calculated output height did not match expected: ")
2856  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2857  }
2858 
2859  // Skip the check when the input width may be broadcast (iw == 1),
2860  // since Linalg, a consumer of TOSA, expects broadcasting support
2861  // in resize to be available. Taking the cautious approach for now,
2862  // we can consider removing support for broadcasting later.
2863  if (iw != ShapedType::kDynamic && iw != 1) {
2864  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2865  const std::optional<int64_t> calculatedOutWidthMinusOne =
2866  idivCheck(scaledInWidth, scaleXD);
2867  if (!calculatedOutWidthMinusOne.has_value())
2868  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2869  "border_x ")
2870  << "to be wholly divisible by scale_x_d, got ((" << iw
2871  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2872  << ") / " << scaleXD;
2873  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2874  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2875  return emitOpError("calculated output width did not match expected: ")
2876  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2877  }
2878 
2879  return success();
2880 }
2881 
2882 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2883  MLIRContext *context, ::std::optional<Location> location,
2884  ScatterOp::Adaptor adaptor,
2885  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2886  llvm::SmallVector<int64_t> outputShape;
2887  outputShape.resize(3, ShapedType::kDynamic);
2888 
2889  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2890  if (valuesInShape.hasRank()) {
2891  outputShape[0] = valuesInShape.getDimSize(0);
2892  outputShape[1] = valuesInShape.getDimSize(1);
2893  outputShape[2] = valuesInShape.getDimSize(2);
2894  }
2895 
2896  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2897  if (indicesShape.hasRank()) {
2898  if (outputShape[0] == ShapedType::kDynamic)
2899  outputShape[0] = indicesShape.getDimSize(0);
2900  }
2901 
2902  ShapeAdaptor inputShape(adaptor.getInput().getType());
2903  if (inputShape.hasRank()) {
2904  if (outputShape[0] == ShapedType::kDynamic)
2905  outputShape[0] = inputShape.getDimSize(0);
2906  if (outputShape[2] == ShapedType::kDynamic)
2907  outputShape[2] = inputShape.getDimSize(2);
2908  }
2909 
2910  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2911  return success();
2912 }
2913 
2914 LogicalResult tosa::ScatterOp::verify() {
2915  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2916  /* outType = */ getValuesOut().getType())
2917  .failed() ||
2918  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2919  /* outType = */ getValuesOut().getType())
2920  .failed()) {
2921  return failure();
2922  }
2923 
2924  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2925  const ShapeAdaptor indicesShape(getIndices().getType());
2926  const ShapeAdaptor inputShape(getInput().getType());
2927  const ShapeAdaptor outputShape(getValuesOut().getType());
2928 
2929  int64_t N = ShapedType::kDynamic;
2930  int64_t K = ShapedType::kDynamic;
2931  int64_t W = ShapedType::kDynamic;
2932  int64_t C = ShapedType::kDynamic;
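      // tosa.scatter shapes: values_in and values_out are [N, K, C], indices is
      // [N, W], input is [N, W, C], and K >= W is required.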
2933  if (valuesInShape.hasRank()) {
2934  N = valuesInShape.getDimSize(0);
2935  K = valuesInShape.getDimSize(1);
2936  C = valuesInShape.getDimSize(2);
2937  }
2938  if (indicesShape.hasRank()) {
2939  const int64_t indicesN = indicesShape.getDimSize(0);
2940  W = indicesShape.getDimSize(1);
2941  if (N == ShapedType::kDynamic)
2942  N = indicesN;
2943  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2944  return emitOpError() << "requires indices dimension 0 to have size " << N
2945  << ", got " << indicesN;
2946  }
2947  if (inputShape.hasRank()) {
2948  const int64_t inputN = inputShape.getDimSize(0);
2949  const int64_t inputW = inputShape.getDimSize(1);
2950  const int64_t inputC = inputShape.getDimSize(2);
2951  if (N == ShapedType::kDynamic)
2952  N = inputN;
2953  else if (inputN != ShapedType::kDynamic && N != inputN)
2954  return emitOpError() << "requires input dimension 0 to have size " << N
2955  << ", got " << inputN;
2956  if (W == ShapedType::kDynamic)
2957  W = inputW;
2958  else if (inputW != ShapedType::kDynamic && W != inputW)
2959  return emitOpError() << "requires input dimension 1 to have size " << W
2960  << ", got " << inputW;
2961 
2962  if (C == ShapedType::kDynamic)
2963  C = inputC;
2964  else if (inputC != ShapedType::kDynamic && C != inputC)
2965  return emitOpError() << "requires input dimension 2 to have size " << C
2966  << ", got " << inputC;
2967  }
2968  if (outputShape.hasRank()) {
2969  const int64_t outputN = outputShape.getDimSize(0);
2970  const int64_t outputK = outputShape.getDimSize(1);
2971  const int64_t outputC = outputShape.getDimSize(2);
2972  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2973  N != outputN)
2974  return emitOpError() << "requires values_out dimension 0 to have size "
2975  << N << ", got " << outputN;
2976  if (K == ShapedType::kDynamic)
2977  K = outputK;
2978  else if (outputK != ShapedType::kDynamic && K != outputK)
2979  return emitOpError() << "requires values_out dimension 1 to have size "
2980  << K << ", got " << outputK;
2981  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2982  C != outputC)
2983  return emitOpError() << "requires values_out dimension 2 to have size "
2984  << C << ", got " << outputC;
2985  }
2986  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2987  return emitOpError() << "requires dimensions K >= W, got K=" << K
2988  << " and W=" << W;
2989 
2990  return success();
2991 }
2992 
2993 static LogicalResult ReduceInferReturnTypes(
2994  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2995  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2996  int64_t axisVal = axis.getValue().getSExtValue();
2997  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2998  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2999  return success();
3000  }
3001 
3002  SmallVector<int64_t> outputShape;
3003  operandShape.getDims(outputShape);
3004  outputShape[axisVal] = 1;
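  // For example, reducing a 5x3x4 operand along axis 1 infers the shape 5x1x4.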
3005  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3006  return success();
3007 }
3008 
3009 #define COMPATIBLE_RETURN_TYPES(OP) \
3010  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3011  if (l.size() != r.size() || l.size() != 1) \
3012  return false; \
3013  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3014  return false; \
3015  return succeeded(verifyCompatibleShape(l[0], r[0])); \
3016  }
3017 
3018 #define REDUCE_SHAPE_INFER(OP) \
3019  LogicalResult OP::inferReturnTypeComponents( \
3020  MLIRContext *context, ::std::optional<Location> location, \
3021  OP::Adaptor adaptor, \
3022  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3023  Type inputType = \
3024  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3025  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3026  const Properties &prop = adaptor.getProperties(); \
3027  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3028  inferredReturnShapes); \
3029  } \
3030  COMPATIBLE_RETURN_TYPES(OP)
3031 
3032 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3033 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3034 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3035 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3036 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3037 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3038 #undef REDUCE_SHAPE_INFER
3039 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3040 #undef COMPATIBLE_RETURN_TYPES
3041 
3042 template <typename T>
3043 static LogicalResult verifyReduceOp(T op) {
3044  // All TOSA reduce Ops have input, output and axis.
3045  TensorType inputType = op.getInput().getType();
3046  TensorType outputType = op.getOutput().getType();
3047  int32_t reduceAxis = op.getAxis();
3048 
3049  if (reduceAxis < 0) {
3050  op.emitOpError("reduce axis must not be negative");
3051  return failure();
3052  }
3053  if (inputType.hasRank()) {
3054  int64_t inputRank = inputType.getRank();
3055  // We allow for a special case where the input/output shape has rank 0 and
3056  // axis is also 0.
3057  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
3058  op.emitOpError("expect input tensor rank (")
3059  << inputRank << ") to be larger than reduce axis (" << reduceAxis
3060  << ")";
3061  return failure();
3062  }
3063  }
3064  if (outputType.hasRank()) {
3065  int64_t outputRank = outputType.getRank();
3066  if (inputType.hasRank() && outputRank != inputType.getRank()) {
3067  op.emitOpError(
3068  "expect output tensor rank to be equal to input tensor rank");
3069  return failure();
3070  }
3071  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
3072  op.emitOpError("expect output tensor rank (")
3073  << outputRank << ") to be larger than reduce axis (" << reduceAxis
3074  << ")";
3075  return failure();
3076  }
3077  // We can only verify the reduced dimension size to be 1 if this is not
3078  // the special case of output rank == 0.
3079  if (outputRank != 0) {
3080  auto outputShape = outputType.getShape();
3081  if (!outputType.isDynamicDim(reduceAxis) &&
3082  outputShape[reduceAxis] != 1) {
3083  op.emitOpError("expect reduced dimension size to be 1, got ")
3084  << outputShape[reduceAxis];
3085  return failure();
3086  }
3087  }
3088  }
3089  return success();
3090 }
3091 
3092 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3093 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3094 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3095 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3096 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3097 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3098 
3099 static LogicalResult NAryInferReturnTypes(
3100  const ValueShapeRange &operands,
3101  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3102  llvm::SmallVector<int64_t> outShape;
3103  if (resolveBroadcastShape(operands, outShape).failed()) {
3104  inferredReturnShapes.push_back(ShapedTypeComponents());
3105  } else {
3106  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3107  }
3108  return success();
3109 }
3110 
3111 #define NARY_SHAPE_INFER(OP) \
3112  LogicalResult OP::inferReturnTypeComponents( \
3113  MLIRContext *context, ::std::optional<Location> location, \
3114  ValueShapeRange operands, DictionaryAttr attributes, \
3115  OpaqueProperties properties, RegionRange regions, \
3116  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3117  return NAryInferReturnTypes(operands, inferredReturnShapes); \
3118  }
3119 
3120 NARY_SHAPE_INFER(tosa::AbsOp)
3121 NARY_SHAPE_INFER(tosa::AddOp)
3122 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3123 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3124 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3125 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3126 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3127 NARY_SHAPE_INFER(tosa::CastOp)
3128 NARY_SHAPE_INFER(tosa::CeilOp)
3129 NARY_SHAPE_INFER(tosa::ClampOp)
3130 NARY_SHAPE_INFER(tosa::ClzOp)
3131 NARY_SHAPE_INFER(tosa::CosOp)
3132 NARY_SHAPE_INFER(tosa::ExpOp)
3133 NARY_SHAPE_INFER(tosa::FloorOp)
3134 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3135 NARY_SHAPE_INFER(tosa::GreaterOp)
3136 NARY_SHAPE_INFER(tosa::IdentityOp)
3137 NARY_SHAPE_INFER(tosa::IntDivOp)
3138 NARY_SHAPE_INFER(tosa::LogOp)
3139 NARY_SHAPE_INFER(tosa::LogicalAndOp)
3140 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3141 NARY_SHAPE_INFER(tosa::LogicalNotOp)
3142 NARY_SHAPE_INFER(tosa::LogicalOrOp)
3143 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3144 NARY_SHAPE_INFER(tosa::LogicalXorOp)
3145 NARY_SHAPE_INFER(tosa::MaximumOp)
3146 NARY_SHAPE_INFER(tosa::MinimumOp)
3147 NARY_SHAPE_INFER(tosa::PowOp)
3148 NARY_SHAPE_INFER(tosa::ReciprocalOp)
3149 NARY_SHAPE_INFER(tosa::ReverseOp)
3150 NARY_SHAPE_INFER(tosa::RsqrtOp)
3151 NARY_SHAPE_INFER(tosa::SinOp)
3152 NARY_SHAPE_INFER(tosa::SelectOp)
3153 NARY_SHAPE_INFER(tosa::SubOp)
3154 NARY_SHAPE_INFER(tosa::TanhOp)
3155 NARY_SHAPE_INFER(tosa::ErfOp)
3156 NARY_SHAPE_INFER(tosa::SigmoidOp)
3157 #undef NARY_SHAPE_INFER
3158 
3159 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3160  MLIRContext *context, ::std::optional<Location> location,
3161  NegateOp::Adaptor adaptor,
3162  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3163  ShapeAdaptor inputShape(adaptor.getInput1().getType());
3164  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3165  return success();
3166 }
3167 
3168 LogicalResult tosa::NegateOp::verify() {
3169  // Verify same element type
3170  const Type input1Type = getInput1().getType();
3171  const Type outputType = getOutput().getType();
3172  if (verifySameElementTypes(*this, input1Type, outputType).failed())
3173  return failure();
3174 
3175  // Verify same shape
3176  const SmallVector<Type, 2> types = {input1Type, outputType};
3177  if (failed(verifyCompatibleShapes(types)))
3178  return emitOpError() << "requires the same shape for input1 and output";
3179 
3180  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3181  const Type input1ZpEType =
3182  getStorageElementTypeOrSelf(getInput1Zp().getType());
3183  if (input1EType != input1ZpEType) {
3184  return emitOpError("expect both input1 and its zero point to have the "
3185  "same element type, got ")
3186  << input1EType << " and " << input1ZpEType;
3187  }
3188  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3189  const Type outputZpEType =
3190  getStorageElementTypeOrSelf(getOutputZp().getType());
3191  if (outputEType != outputZpEType) {
3192  return emitOpError("expect both output and its zero point to have the "
3193  "same element type, got ")
3194  << outputEType << " and " << outputZpEType;
3195  }
3196 
3197  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3198  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3199  return failure();
3200 
3201  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3202  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3203  return failure();
3204 
3205  return success();
3206 }
3207 
3208 static LogicalResult poolingInferReturnTypes(
3209  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3210  ArrayRef<int64_t> pad,
3211  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3212  llvm::SmallVector<int64_t> outputShape;
3213  outputShape.resize(4, ShapedType::kDynamic);
3214 
3215  // If the input is unranked we only know the output rank (4); all dims stay dynamic.
3216  if (!inputShape) {
3217  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3218  return success();
3219  }
3220 
3221  // The batch and channel dimensions pass through the pooling layer unchanged.
3222  outputShape[0] = inputShape.getDimSize(0);
3223  outputShape[3] = inputShape.getDimSize(3);
3224 
3225  int64_t height = inputShape.getDimSize(1);
3226  int64_t width = inputShape.getDimSize(2);
3227 
3228  if (ShapedType::isStatic(height)) {
3229  int64_t padded = height + pad[0] + pad[1] - kernel[0];
3230  outputShape[1] = padded / stride[0] + 1;
3231  }
3232 
3233  if (ShapedType::isStatic(width)) {
3234  int64_t padded = width + pad[2] + pad[3] - kernel[1];
3235  outputShape[2] = padded / stride[1] + 1;
3236  }
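  // For example, height 32 with pad = [1, 1], kernel 3 and stride 2 gives a
  // padded extent of 32 + 1 + 1 - 3 = 31 and an output height of
  // 31 / 2 + 1 = 16 (integer division).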
3237 
3238  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3239  return success();
3240 }
3241 
3242 LogicalResult Conv2DOp::inferReturnTypeComponents(
3243  MLIRContext *context, ::std::optional<Location> location,
3244  Conv2DOp::Adaptor adaptor,
3245  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3246  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3247 
3248  int64_t inputWidth = ShapedType::kDynamic;
3249  int64_t inputHeight = ShapedType::kDynamic;
3250  int64_t weightWidth = ShapedType::kDynamic;
3251  int64_t weightHeight = ShapedType::kDynamic;
3252 
3253  // Input shape describes input width/height and batch.
3254 
3255  ShapeAdaptor inputShape(adaptor.getInput().getType());
3256  if (inputShape.hasRank()) {
3257  outputShape[0] = inputShape.getDimSize(0);
3258  inputHeight = inputShape.getDimSize(1);
3259  inputWidth = inputShape.getDimSize(2);
3260  }
3261 
3262  // Weight shape describes the filter width/height and the output channels.
3263  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3264  if (weightShape.hasRank()) {
3265  outputShape[3] = weightShape.getDimSize(0);
3266  weightHeight = weightShape.getDimSize(1);
3267  weightWidth = weightShape.getDimSize(2);
3268  }
3269 
3270  // Bias shape can describe the output channels.
3271  ShapeAdaptor biasShape(adaptor.getBias().getType());
3272  if (biasShape.hasRank()) {
3273  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3274  ? biasShape.getDimSize(0)
3275  : outputShape[3];
3276  }
3277 
3278  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3279  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3280  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3281 
3282  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3283  int64_t inputSize = inputHeight + padding[0] + padding[1];
3284  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3285  int64_t unstridedResult = inputSize - filterSize + 1;
3286  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3287  }
3288 
3289  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3290  int64_t inputSize = inputWidth + padding[2] + padding[3];
3291  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3292  int64_t unstridedResult = inputSize - filterSize + 1;
3293  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3294  }
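  // For example, input width 16 with pad = [1, 1], kernel width 3, dilation 1
  // and stride 2: inputSize = 18, filterSize = 3, unstridedResult = 16, so the
  // output width is (16 - 1) / 2 + 1 = 8.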
3295 
3296  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3297  return success();
3298 }
3299 
3300 LogicalResult Conv2DOp::verify() {
3301  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3302  verifyConvOpErrorIf(*this).failed())
3303  return failure();
3304  return success();
3305 }
3306 
3307 LogicalResult Conv3DOp::inferReturnTypeComponents(
3308  MLIRContext *context, ::std::optional<Location> location,
3309  Conv3DOp::Adaptor adaptor,
3310  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3311  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3312 
3313  int64_t inputWidth = ShapedType::kDynamic;
3314  int64_t inputHeight = ShapedType::kDynamic;
3315  int64_t inputDepth = ShapedType::kDynamic;
3316 
3317  int64_t weightWidth = ShapedType::kDynamic;
3318  int64_t weightHeight = ShapedType::kDynamic;
3319  int64_t weightDepth = ShapedType::kDynamic;
3320 
3321  // Input shape describes input width/height and batch.
3322  ShapeAdaptor inputShape(adaptor.getInput().getType());
3323  if (inputShape.hasRank()) {
3324  outputShape[0] = inputShape.getDimSize(0);
3325  inputDepth = inputShape.getDimSize(1);
3326  inputHeight = inputShape.getDimSize(2);
3327  inputWidth = inputShape.getDimSize(3);
3328  }
3329 
3330  // Weight shape describes the filter width/height and the output channels.
3331  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3332  if (weightShape.hasRank()) {
3333  outputShape[4] = weightShape.getDimSize(0);
3334  weightDepth = weightShape.getDimSize(1);
3335  weightHeight = weightShape.getDimSize(2);
3336  weightWidth = weightShape.getDimSize(3);
3337  }
3338 
3339  // Bias shape can describe the output channels.
3340  ShapeAdaptor biasShape(adaptor.getBias().getType());
3341  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3342  outputShape[4] = biasShape.getDimSize(0);
3343  }
3344 
3345  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3346  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3347  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3348 
3349  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3350  int64_t inputSize = inputDepth + pad[0] + pad[1];
3351  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3352  int64_t unstridedResult = inputSize - filterSize + 1;
3353  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3354  }
3355 
3356  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3357  int64_t inputSize = inputHeight + pad[2] + pad[3];
3358  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3359  int64_t unstridedResult = inputSize - filterSize + 1;
3360  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3361  }
3362 
3363  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3364  int64_t inputSize = inputWidth + pad[4] + pad[5];
3365  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3366  int64_t unstridedResult = inputSize - filterSize + 1;
3367  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3368  }
3369 
3370  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3371  return success();
3372 }
3373 
3374 LogicalResult Conv3DOp::verify() {
3375  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3376  verifyConvOpErrorIf(*this).failed())
3377  return failure();
3378  return success();
3379 }
3380 
3381 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3382  MLIRContext *context, ::std::optional<Location> location,
3383  AvgPool2dOp::Adaptor adaptor,
3384  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3385  ShapeAdaptor inputShape(adaptor.getInput().getType());
3386  const Properties &prop = adaptor.getProperties();
3387  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3388  inferredReturnShapes);
3389 }
3390 
3391 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3392  MLIRContext *context, ::std::optional<Location> location,
3393  MaxPool2dOp::Adaptor adaptor,
3394  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3395  ShapeAdaptor inputShape(adaptor.getInput().getType());
3396  const Properties &prop = adaptor.getProperties();
3397  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3398  inferredReturnShapes);
3399 }
3400 
3401 LogicalResult MaxPool2dOp::verify() {
3402  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3403  /* outType = */ getOutput().getType())))
3404  return failure();
3405 
3406  if (failed(verifyPoolingOp(*this)))
3407  return failure();
3408 
3409  return success();
3410 }
3411 
3412 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3413  MLIRContext *context, ::std::optional<Location> location,
3414  DepthwiseConv2DOp::Adaptor adaptor,
3415  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3416  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3417 
3418  int64_t inputWidth = ShapedType::kDynamic;
3419  int64_t inputHeight = ShapedType::kDynamic;
3420  int64_t inputChannels = ShapedType::kDynamic;
3421 
3422  int64_t weightWidth = ShapedType::kDynamic;
3423  int64_t weightHeight = ShapedType::kDynamic;
3424  int64_t depthChannels = ShapedType::kDynamic;
3425 
3426  // Input shape describes input width/height and batch.
3427  ShapeAdaptor inputShape(adaptor.getInput().getType());
3428  if (inputShape.hasRank()) {
3429  outputShape[0] = inputShape.getDimSize(0);
3430  inputHeight = inputShape.getDimSize(1);
3431  inputWidth = inputShape.getDimSize(2);
3432  inputChannels = inputShape.getDimSize(3);
3433  }
3434 
3435  // Weight shape describes the filter height/width, input channels and the depth multiplier.
3436  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3437  if (weightShape.hasRank()) {
3438  weightHeight = weightShape.getDimSize(0);
3439  weightWidth = weightShape.getDimSize(1);
3440  inputChannels = ShapedType::isDynamic(inputChannels)
3441  ? weightShape.getDimSize(2)
3442  : inputChannels;
3443  depthChannels = weightShape.getDimSize(3);
3444  }
3445 
3446  // If both inputChannels and depthChannels are available we can determine
3447  // the output channels.
3448  if (ShapedType::isStatic(inputChannels) &&
3449  ShapedType::isStatic(depthChannels)) {
3450  outputShape[3] = inputChannels * depthChannels;
3451  }
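  // For example, 8 input channels with a depth multiplier of 3 yield
  // 8 * 3 = 24 output channels.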
3452 
3453  // Bias shape can describe the output channels.
3454  ShapeAdaptor biasShape(adaptor.getBias().getType());
3455  if (biasShape.hasRank()) {
3456  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3457  ? biasShape.getDimSize(0)
3458  : outputShape[3];
3459  }
3460 
3461  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3462  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3463  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3464 
3465  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3466  int64_t inputSize = inputHeight + padding[0] + padding[1];
3467  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3468  int64_t unstridedResult = inputSize - filterSize + 1;
3469  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3470  }
3471 
3472  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3473  int64_t inputSize = inputWidth + padding[2] + padding[3];
3474  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3475  int64_t unstridedResult = inputSize - filterSize + 1;
3476  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3477  }
3478 
3479  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3480  return success();
3481 }
3482 
3483 LogicalResult DepthwiseConv2DOp::verify() {
3484  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3485  verifyConvOpErrorIf(*this).failed())
3486  return failure();
3487  return success();
3488 }
3489 
3490 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3491  MLIRContext *context, ::std::optional<Location> location,
3492  TransposeConv2DOp::Adaptor adaptor,
3493  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3494  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3495 
3496  int64_t inputWidth = ShapedType::kDynamic;
3497  int64_t inputHeight = ShapedType::kDynamic;
3498  int64_t weightWidth = ShapedType::kDynamic;
3499  int64_t weightHeight = ShapedType::kDynamic;
3500 
3501  // Input shape describes input width/height and batch.
3502  ShapeAdaptor inputShape(adaptor.getInput().getType());
3503  if (inputShape.hasRank()) {
3504  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3505  ? inputShape.getDimSize(0)
3506  : outputShape[0];
3507  inputHeight = inputShape.getDimSize(1);
3508  inputWidth = inputShape.getDimSize(2);
3509  }
3510 
3511  // Weight shape describes the filter width/height and the output channels.
3512  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3513  if (weightShape.hasRank()) {
3514  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3515  ? weightShape.getDimSize(0)
3516  : outputShape[3];
3517  weightHeight = weightShape.getDimSize(1);
3518  weightWidth = weightShape.getDimSize(2);
3519  }
3520 
3521  // Bias shape can describe the output channels.
3522  ShapeAdaptor biasShape(adaptor.getBias().getType());
3523  if (biasShape.hasRank()) {
3524  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3525  ? biasShape.getDimSize(0)
3526  : outputShape[3];
3527  }
3528 
3529  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3530  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3531 
3532  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3533  int64_t calculateSize =
3534  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3535  outputShape[1] =
3536  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3537  }
3538 
3539  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3540  int64_t calculateSize =
3541  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3542  outputShape[2] =
3543  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3544  }
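  // For example, input height 4, stride 2, out_pad = [0, 0] and kernel height 3
  // give an output height of (4 - 1) * 2 + 0 + 0 + 3 = 9.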
3545 
3546  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3547  return success();
3548 }
3549 
3550 LogicalResult TransposeConv2DOp::verify() {
3551  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3552  return failure();
3553 
3554  const llvm::ArrayRef<int64_t> strides = getStride();
3555  const int64_t strideY = strides[0];
3556  const int64_t strideX = strides[1];
3557 
3558  if (strideY < 1 || strideX < 1)
3559  return emitOpError("expect all stride values to be >= 1, got [")
3560  << strides << "]";
3561 
3562  const auto checkPadAgainstKernelDim =
3563  [this](int64_t pad_value, int64_t kernel_dim_size,
3564  llvm::StringRef pad_name,
3565  llvm::StringRef kernel_dim_name) -> LogicalResult {
3566  if (pad_value <= -kernel_dim_size)
3567  return emitOpError("expected ")
3568  << pad_name << " > -" << kernel_dim_name
3569  << ", but got: " << pad_name << "=" << pad_value << " and "
3570  << kernel_dim_name << "=" << kernel_dim_size;
3571  return success();
3572  };
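  // The padding may be negative, but only up to the kernel size: with KH = 3,
  // out_pad_top = -2 is accepted while out_pad_top = -3 is rejected.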
3573 
3574  const llvm::ArrayRef<int64_t> padding = getOutPad();
3575  const int64_t outPadTop = padding[0];
3576  const int64_t outPadBottom = padding[1];
3577  const int64_t outPadLeft = padding[2];
3578  const int64_t outPadRight = padding[3];
3579 
3580  const auto weightType =
3581  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3582 
3583  if (weightType) {
3584  const int64_t kernelHeight = weightType.getDimSize(1);
3585  if (ShapedType::isStatic(kernelHeight)) {
3586  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3587  "out_pad_top", "KH")))
3588  return failure();
3589 
3590  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3591  "out_pad_bottom", "KH")))
3592  return failure();
3593  }
3594 
3595  const int64_t kernelWidth = weightType.getDimSize(2);
3596  if (ShapedType::isStatic(kernelWidth)) {
3597  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3598  "out_pad_left", "KW")))
3599  return failure();
3600 
3601  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3602  "out_pad_right", "KW")))
3603  return failure();
3604  }
3605  }
3606 
3607  // Rest of the checks depend on the output type being a RankedTensorType
3608  const auto outputType =
3609  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3610  if (!outputType)
3611  return success();
3612 
3613  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3614  if (inputType && weightType) {
3615  const int64_t inputHeight = inputType.getDimSize(1);
3616  const int64_t kernelHeight = weightType.getDimSize(1);
3617  const int64_t outputHeight = outputType.getDimSize(1);
3618 
3619  if (ShapedType::isStatic(inputHeight) &&
3620  ShapedType::isStatic(outputHeight)) {
3621  if (outputHeight !=
3622  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3623  return emitOpError(
3624  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3625  "+ out_pad_top + out_pad_bottom + KH, but got ")
3626  << outputHeight << " != (" << inputHeight << " - 1) * "
3627  << strideY << " + " << outPadTop << " + " << outPadBottom
3628  << " + " << kernelHeight;
3629  }
3630 
3631  const int64_t inputWidth = inputType.getDimSize(2);
3632  const int64_t kernelWidth = weightType.getDimSize(2);
3633  const int64_t outputWidth = outputType.getDimSize(2);
3634 
3635  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3636  if (outputWidth !=
3637  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3638  return emitOpError(
3639  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3640  "+ out_pad_left + out_pad_right + KW, but got ")
3641  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3642  << " + " << outPadLeft << " + " << outPadRight << " + "
3643  << kernelWidth;
3644  }
3645  }
3646 
3647  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3648 
3649  if (!biasType)
3650  return success();
3651 
3652  const int64_t biasChannels = biasType.getDimSize(0);
3653 
3654  // Skip further checks if bias is dynamic
3655  if (biasChannels == ShapedType::kDynamic)
3656  return success();
3657 
3658  const int64_t outputChannels = outputType.getDimSize(3);
3659  if (!ShapedType::isDynamic(outputChannels) &&
3660  biasChannels != outputChannels && biasChannels != 1)
3661  return emitOpError(
3662  "bias channels expected to be equal to output channels (")
3663  << outputChannels << ") or 1, got " << biasChannels;
3664 
3665  return success();
3666 }
3667 
3668 LogicalResult RescaleOp::verify() {
3669  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3670  if (!inputType) {
3671  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3672  return failure();
3673  }
3674 
3675  auto inputElementType =
3676  getStorageElementTypeOrSelf(inputType.getElementType());
3677  if (!mlir::isa<IntegerType>(inputElementType)) {
3678  emitOpError("expect input to have integer element type, got ")
3679  << inputElementType;
3680  return failure();
3681  }
3682 
3683  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3684  if (!outputType) {
3685  emitOpError("expect shaped tensor for output, got ")
3686  << getOutput().getType();
3687  return failure();
3688  }
3689 
3690  auto outputElementType =
3691  getStorageElementTypeOrSelf(outputType.getElementType());
3692  if (!mlir::isa<IntegerType>(outputElementType)) {
3693  emitOpError("expect output to have integer element type, got ")
3694  << outputElementType;
3695  return failure();
3696  }
3697 
3698  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3699  .failed())
3700  return failure();
3701 
3702  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3703  .failed())
3704  return failure();
3705 
3706  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3707  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3708  return failure();
3709 
3710  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3711  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3712  return failure();
3713 
3714  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3715  if (!multiplierType) {
3716  emitOpError("expect shaped tensor for multiplier, got ")
3717  << getMultiplier().getType();
3718  return failure();
3719  }
3720 
3721  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3722  if (!shiftType) {
3723  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3724  return failure();
3725  }
3726 
3727  // multiplier element type must be i32 for scale32 = true
3728  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3729  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3730  << multiplierType.getElementType();
3731  return failure();
3732  }
3733 
3734  // multiplier element type must be i16 for scale32 = false
3735  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3736  emitOpError(
3737  "expect i16 element type for multiplier for scale32=false, got ")
3738  << multiplierType.getElementType();
3739  return failure();
3740  }
3741 
3742  if (!inputType.hasRank())
3743  return success();
3744 
3745  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3746  // if per_channel = false, otherwise it is the size of the input shape's
3747  // last axis
3748  int64_t numChannels = 1;
3749  if (getPerChannel()) {
3750  if (inputType.getRank() < 1) {
3751  emitOpError("requires input to be at least rank 1 when per_channel is "
3752  "true, but got rank ")
3753  << inputType.getRank();
3754  return failure();
3755  }
3756  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3757  }
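  // For example, with per_channel = true and an input of type tensor<2x3x16xi8>,
  // multiplier and shift must both have shape { 16 }; with per_channel = false
  // they must have shape { 1 }.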
3758 
3759  if (!multiplierType.hasRank())
3760  return success();
3761 
3762  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3763  // multiplier input has rank 1 by dialect definition
3764  if (multiplierShape[0] != ShapedType::kDynamic &&
3765  multiplierShape[0] != numChannels) {
3766  emitOpError("expect shape of { ")
3767  << numChannels << " } for multiplier input, got { "
3768  << multiplierShape[0] << " }";
3769  return failure();
3770  }
3771 
3772  if (!shiftType.hasRank())
3773  return success();
3774 
3775  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3776  // shift input has rank 1 by dialect definition
3777  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3778  emitOpError("expect shape of { ")
3779  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3780  return failure();
3781  }
3782 
3783  return success();
3784 }
3785 
3786 LogicalResult RescaleOp::inferReturnTypeComponents(
3787  MLIRContext *context, ::std::optional<Location> location,
3788  RescaleOp::Adaptor adaptor,
3789  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3790  ShapeAdaptor inputShape(adaptor.getInput().getType());
3791  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3792  return success();
3793 }
3794 
3795 LogicalResult IfOp::inferReturnTypeComponents(
3796  MLIRContext *context, ::std::optional<Location> location,
3797  IfOp::Adaptor adaptor,
3798  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3799  llvm::SmallVector<tosa::YieldOp> yieldOps;
3800  for (Region *region : adaptor.getRegions()) {
3801  for (auto &block : *region)
3802  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3803  yieldOps.push_back(returnOp);
3804  }
3805 
3806  if (yieldOps.empty())
3807  return failure();
3808 
3809  // Get the initial type information for the yield op.
3810  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3811  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3812  for (auto operand : yieldOps.front().getOperands()) {
3813  resultKnowledge.push_back(
3814  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3815  }
3816 
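  // Combine the knowledge from every yield: dimensions on which all yields
  // agree stay static, anything that disagrees or is unknown becomes dynamic.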
3817  for (auto yieldOp : yieldOps) {
3818  if (resultKnowledge.size() != yieldOp.getNumOperands())
3819  return failure();
3820 
3821  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3822  int32_t index = it.index();
3823  auto meet = ValueKnowledge::meet(
3824  resultKnowledge[index],
3825  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3826  if (!meet)
3827  continue;
3828  resultKnowledge[index] = meet;
3829  }
3830  }
3831 
3832  for (const ValueKnowledge &result : resultKnowledge) {
3833  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3834  }
3835 
3836  return success();
3837 }
3838 
3839 LogicalResult WhileOp::inferReturnTypeComponents(
3840  MLIRContext *context, ::std::optional<Location> location,
3841  WhileOp::Adaptor adaptor,
3842  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3843  llvm::SmallVector<tosa::YieldOp> yieldOps;
3844  for (auto &block : adaptor.getBodyGraph())
3845  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3846  yieldOps.push_back(returnOp);
3847 
3848  // TOSA's while must have a tosa.yield as its terminator. If none is found,
3849  // this tosa.while is invalid.
3850  if (yieldOps.empty())
3851  return failure();
3852 
3853  // Get the initial type information from the operand types.
3854  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3855  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3856  for (auto operand : yieldOps.front().getOperands()) {
3857  resultKnowledge.push_back(
3858  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3859  }
3860 
3861  for (auto yieldOp : yieldOps) {
3862  if (resultKnowledge.size() != yieldOp.getNumOperands())
3863  return failure();
3864 
3865  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3866  int32_t index = it.index();
3867  if (auto meet = ValueKnowledge::meet(
3868  resultKnowledge[index],
3869  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3870  resultKnowledge[index] = meet;
3871  }
3872  }
3873  }
3874 
3875  for (const ValueKnowledge &result : resultKnowledge) {
3876  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3877  }
3878 
3879  return success();
3880 }
3881 
3882 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3883  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3884  return llvm::to_vector<4>(vt.getShape());
3885  return std::nullopt;
3886 }
3887 
3888 static void printInitializationList(OpAsmPrinter &parser,
3889  Block::BlockArgListType blocksArgs,
3890  ValueRange initializers,
3891  StringRef prefix = "") {
3892  assert(blocksArgs.size() == initializers.size() &&
3893  "expected same length of arguments and initializers");
3894  if (initializers.empty())
3895  return;
3896 
3897  parser << prefix << '(';
3898  llvm::interleaveComma(
3899  llvm::zip(blocksArgs, initializers), parser,
3900  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3901  parser << ")";
3902 }
3903 
3904 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
3905 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3906  // Create the regions for 'then'.
3907  result.regions.reserve(2);
3908  Region *thenRegion = result.addRegion();
3909  Region *elseRegion = result.addRegion();
3910 
3911  OpAsmParser::UnresolvedOperand cond;
3912 
3913  if (parser.parseOperand(cond))
3914  return failure();
3915 
3916  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3917  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3918 
3919  // Parse the optional block arguments
3920  OptionalParseResult listResult =
3921  parser.parseOptionalAssignmentList(regionArgs, operands);
3922  if (listResult.has_value() && failed(listResult.value()))
3923  return failure();
3924 
3925  // Parse a colon.
3926  if (failed(parser.parseColon()))
3927  return parser.emitError(parser.getCurrentLocation(),
3928  "expected type for condition operand");
3929 
3930  // Parse the type of the condition operand
3931  Type condType;
3932  if (failed(parser.parseType(condType)))
3933  return parser.emitError(parser.getCurrentLocation(),
3934  "expected type for condition operand");
3935 
3936  // Resolve operand with provided type
3937  if (failed(parser.resolveOperand(cond, condType, result.operands)))
3938  return failure();
3939 
3940  // Parse optional block arg types
3941  if (listResult.has_value()) {
3942  FunctionType functionType;
3943 
3944  if (failed(parser.parseType(functionType)))
3945  return parser.emitError(parser.getCurrentLocation())
3946  << "expected list of types for block arguments "
3947  << "followed by arrow type and list of return types";
3948 
3949  result.addTypes(functionType.getResults());
3950 
3951  if (functionType.getNumInputs() != operands.size()) {
3952  return parser.emitError(parser.getCurrentLocation())
3953  << "expected as many input types as operands "
3954  << "(expected " << operands.size() << " got "
3955  << functionType.getNumInputs() << ")";
3956  }
3957 
3958  // Resolve input operands.
3959  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3960  parser.getCurrentLocation(),
3961  result.operands)))
3962  return failure();
3963  } else {
3964  // Parse optional results type list.
3965  if (parser.parseOptionalArrowTypeList(result.types))
3966  return failure();
3967  }
3968 
3969  // Parse the 'then' region.
3970  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3971  return failure();
3972 
3973  // If we find an 'else' keyword then parse the 'else' region.
3974  if (!parser.parseOptionalKeyword("else")) {
3975  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3976  return failure();
3977  }
3978 
3979  // Parse the optional attribute list.
3980  if (parser.parseOptionalAttrDict(result.attributes))
3981  return failure();
3982  return success();
3983 }
3984 
3985 void IfOp::print(OpAsmPrinter &p) {
3986  p << " " << getCondition();
3987 
3988  printInitializationList(p, getThenGraph().front().getArguments(),
3989  getInputList(), " ");
3990  p << " : ";
3991  p << getCondition().getType();
3992 
3993  if (!getInputList().empty()) {
3994  p << " (";
3995  llvm::interleaveComma(getInputList().getTypes(), p);
3996  p << ")";
3997  }
3998  p.printArrowTypeList(getResultTypes());
3999  p << " ";
4000 
4001  p.printRegion(getThenGraph());
4002 
4003  // Print the 'else' region if it exists and has a block.
4004  auto &elseRegion = getElseGraph();
4005  if (!elseRegion.empty()) {
4006  p << " else ";
4007  p.printRegion(elseRegion);
4008  }
4009 
4010  p.printOptionalAttrDict((*this)->getAttrs());
4011 }
4012 
4013 LogicalResult IfOp::verify() {
4014  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4015  "'then_graph' arguments", getInputList(),
4016  "'input_list'")
4017  .failed())
4018  return failure();
4019 
4020  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4021  "'else_graph' arguments", getInputList(),
4022  "'input_list'")
4023  .failed())
4024  return failure();
4025 
4026  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4027  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
4028  "'then_graph' results", getOutputList(),
4029  "'output_list'")
4030  .failed())
4031  return failure();
4032 
4033  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4034  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
4035  "'else_graph' results", getOutputList(),
4036  "'output_list'")
4037  .failed())
4038  return failure();
4039 
4040  auto condType = getCondition().getType();
4041  if (errorIfShapeNotSizeOne(*this, condType).failed())
4042  return emitOpError() << "'condition' must be a size 1 tensor, got "
4043  << condType;
4044 
4045  return success();
4046 }
4047 
4048 LogicalResult WhileOp::verify() {
4049  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4050  getOutputList(), "'output_list'")
4051  .failed())
4052  return failure();
4053 
4054  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4055  "'cond_graph' arguments", getInputList(),
4056  "'input_list'")
4057  .failed())
4058  return failure();
4059 
4060  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4061  "'body_graph' arguments", getInputList(),
4062  "'input_list'")
4063  .failed())
4064  return failure();
4065 
4066  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4067  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4068  "'body_graph' results", getInputList(),
4069  "'input_list'")
4070  .failed())
4071  return failure();
4072 
4073  // Condition block output must be a single element tensor with a single bool
4074  // value.
4075  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4076  if (condYield.getInputs().size() != 1)
4077  return emitOpError() << "requires 'cond_graph' to have exactly one result";
4078 
4079  auto condOutType = condYield.getInputs()[0].getType();
4080  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4081  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4082  << condOutType;
4083 
4084  if (!getElementTypeOrSelf(condOutType).isInteger(1))
4085  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4086  << condOutType;
4087 
4088  return success();
4089 }
4090 
4091 LogicalResult ReverseOp::verify() {
4092  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4093  /* outType = */ getOutput().getType())
4094  .failed())
4095  return failure();
4096  TensorType inputType = getInput1().getType();
4097  TensorType outputType = getOutput().getType();
4098  int32_t reverseAxis = getAxis();
4099 
4100  if (reverseAxis < 0)
4101  return emitOpError("expected non-negative reverse axis");
4102  if (inputType.hasRank()) {
4103  int64_t inputRank = inputType.getRank();
4104  // We allow for a special case where the input/output shape has rank 0 and
4105  // axis is also 0.
4106  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
4107  return emitOpError("expect input tensor rank (")
4108  << inputRank << ") to be larger than reverse axis (" << reverseAxis
4109  << ")";
4110  }
4111  if (outputType.hasRank()) {
4112  int64_t outputRank = outputType.getRank();
4113  if (inputType.hasRank() && outputRank != inputType.getRank())
4114  return emitOpError(
4115  "expect output tensor rank to be equal to input tensor rank");
4116  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
4117  return emitOpError("expect output tensor rank (")
4118  << outputRank << ") to be larger than reverse axis ("
4119  << reverseAxis << ")";
4120  }
4121  return success();
4122 }
4123 
4124 LogicalResult tosa::SelectOp::verify() {
4125  // verify input2 and input3 have same element type as output
4126  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4127  /* outType = */ getOutput().getType())
4128  .failed() ||
4129  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4130  /* outType = */ getOutput().getType())
4131  .failed()) {
4132  return failure();
4133  }
4134  // verify input1 has element type of bool
4135  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4136  if (!predicateType) {
4137  return emitOpError("expect shaped tensor for input1, got ")
4138  << getInput1().getType();
4139  }
4140  auto predicateElementType = predicateType.getElementType();
4141  if (!predicateElementType.isInteger(1)) {
4142  return emitOpError("expect element type of bool for input1, got ")
4143  << predicateElementType;
4144  }
4145 
4146  return success();
4147 }
4148 
4149 LogicalResult tosa::VariableOp::verify() {
4150  StringRef symName = getName();
4151  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
4152  if (succeeded(varOp))
4153  return emitOpError("illegal to have multiple declarations of '")
4154  << symName << "'";
4155 
4156  return success();
4157 }
4158 
4159 LogicalResult tosa::VariableReadOp::verify() {
4160  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4161  .failed())
4162  return failure();
4163 
4164  return success();
4165 }
4166 
4167 LogicalResult tosa::VariableWriteOp::verify() {
4168  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4169  .failed())
4170  return failure();
4171 
4172  return success();
4173 }
4174 
4175 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
4176 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4177  SmallVector<OpAsmParser::Argument, 4> regionArgs;
4178  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4179  Region *cond = result.addRegion();
4180  Region *body = result.addRegion();
4181 
4182  OptionalParseResult listResult =
4183  parser.parseOptionalAssignmentList(regionArgs, operands);
4184  if (listResult.has_value() && failed(listResult.value()))
4185  return failure();
4186 
4187  FunctionType functionType;
4188  SMLoc typeLoc = parser.getCurrentLocation();
4189  if (failed(parser.parseColonType(functionType)))
4190  return failure();
4191 
4192  result.addTypes(functionType.getResults());
4193 
4194  if (functionType.getNumInputs() != operands.size()) {
4195  return parser.emitError(typeLoc)
4196  << "expected as many input types as operands "
4197  << "(expected " << operands.size() << " got "
4198  << functionType.getNumInputs() << ")";
4199  }
4200 
4201  // Resolve input operands.
4202  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4203  parser.getCurrentLocation(),
4204  result.operands)))
4205  return failure();
4206 
4207  // Propagate the types into the region arguments.
4208  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4209  regionArgs[i].type = functionType.getInput(i);
4210 
4211  return failure(parser.parseRegion(*cond, regionArgs) ||
4212  parser.parseKeyword("do") || parser.parseRegion(*body) ||
4213  parser.parseOptionalAttrDictWithKeyword(result.attributes));
4214 }
4215 
4216 void WhileOp::print(OpAsmPrinter &parser) {
4217  printInitializationList(parser, getCondGraph().front().getArguments(),
4218  getInputList(), " ");
4219  parser << " : ";
4220  parser.printFunctionalType(getInputList().getTypes(),
4221  getResults().getTypes());
4222  parser << ' ';
4223  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4224  parser << " do ";
4225  parser.printRegion(getBodyGraph());
4226  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4227 }
4228 
4229 // Create a rank-1 const tensor for zero point of the source tensor.
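// For example, requesting an i8 zero point of value -128 yields a tosa.const
// of type tensor<1xi8> holding the splat value -128.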
4230 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4231  Location loc,
4232  Type srcElemType,
4233  int64_t zp) {
4234  srcElemType = getStorageElementTypeOrSelf(srcElemType);
4235  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4236  if (llvm::isa<FloatType>(srcElemType)) {
4237  auto zpAttr = DenseElementsAttr::get(
4238  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4239  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4240  }
4241  if (llvm::isa<IntegerType>(srcElemType)) {
4242  auto zpAttr =
4243  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4244  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4245  }
4246  llvm::errs() << "zero point is not supported for this element type\n";
4247  return std::nullopt;
4248 }
4249 
4250 //===----------------------------------------------------------------------===//
4251 // TOSA Shape and Shape Operators Helper functions.
4252 //===----------------------------------------------------------------------===//
4253 
4254 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4255  return mlir::isa<tosa::shapeType>(t);
4256 }
4257 
4258 LogicalResult
4259 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4260  int rank) {
4261  if (rank < 0)
4262  return emitError() << "invalid rank (must be >= 0): " << rank;
4263  return success();
4264 }
4265 
4267  for (auto v : op->getOperands()) {
4268  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4269  Operation *definingOp = v.getDefiningOp();
4270  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4271  return op->emitOpError("shape operand is not compile time resolvable");
4272  }
4273  }
4274  }
4275  return success();
4276 }
4277 
4279  for (auto type : op->getOperandTypes()) {
4280  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4281  return op->emitOpError("must have operands with tosa shape type");
4282  }
4283  }
4284  for (auto type : op->getResultTypes()) {
4285  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4286  return op->emitOpError("must have result with tosa shape type");
4287  }
4288  }
4289  return success();
4290 }
4291 
4292 LogicalResult
4296  return failure();
4297 
4298  // delegate function that returns rank of shape type
4299  auto getRank = [](const Type type) {
4300  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4301  };
4302  auto operandTypes = op->getOperandTypes();
4303  auto resultTypes = op->getResultTypes();
4304 
4305  auto rank = getRank(*op->getOperandTypes().begin());
4306  for (auto type : operandTypes) {
4307  if (getRank(type) != rank) {
4308  return op->emitOpError("operands don't have matching ranks");
4309  }
4310  }
4311  for (auto type : resultTypes) {
4312  if (getRank(type) != rank) {
4313  return op->emitOpError("result shape has different rank than operands");
4314  }
4315  }
4316  return success();
4317 }
4318 
4319 //===----------------------------------------------------------------------===//
4320 // TOSA Shape Operators verify functions.
4321 //===----------------------------------------------------------------------===//
4322 
4323 LogicalResult tosa::ConstShapeOp::verify() {
4324  // check that the values attribute is one-dimensional (rank 1)
4325  auto valuesRank = getValues().getType().getRank();
4326  if (valuesRank != 1)
4327  return emitOpError("expect elements in attribute values with rank 1");
4328  // check that the number of elements in the values attr equals the rank of the result shape
4329  auto count = getValues().getNumElements();
4330  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
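  // For example, a result of type !tosa.shape<3> requires exactly 3 elements
  // in the values attribute; a rank-0 result may carry a single element.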
4331  if (!(count == rank || (count == 1 && rank == 0))) {
4332  return emitOpError("expect number of elements in attribute values (")
4333  << count << ") to be equal to the rank (" << rank
4334  << ") for the result shape type";
4335  }
4336  return success();
4337 }
4338 
4339 //===----------------------------------------------------------------------===//
4340 // TOSA Attribute Definitions.
4341 //===----------------------------------------------------------------------===//
4342 
4343 #define GET_ATTRDEF_CLASSES
4344 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4345 
4346 //===----------------------------------------------------------------------===//
4347 // TOSA Type Definitions.
4348 //===----------------------------------------------------------------------===//
4349 #define GET_TYPEDEF_CLASSES
4350 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4351 
4352 //===----------------------------------------------------------------------===//
4353 // TOSA Operator Definitions.
4354 //===----------------------------------------------------------------------===//
4355 
4356 #define GET_OP_CLASSES
4357 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:1272
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition: TosaOps.cpp:2420
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:963
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2993
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition: TosaOps.cpp:532
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
Definition: TosaOps.cpp:910
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition: TosaOps.cpp:900
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:3018
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:572
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition: TosaOps.cpp:944
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:3208
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:1386
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition: TosaOps.cpp:1400
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:3043
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:3111
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition: TosaOps.cpp:2490
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:1250
static LogicalResult verifyConvOpErrorIf(T op)
Definition: TosaOps.cpp:718
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:677
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:515
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:3099
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:3009
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:1431
Type getStorageElementTypeOrSelf(Type type)
Definition: TosaOps.cpp:521
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on the single-parameter negate operator to construct input and output zero points ...
Definition: TosaOps.cpp:1346
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:1226
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr, but the avg_pool operator has...
Definition: TosaOps.cpp:1301
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition: TosaOps.cpp:858
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:136
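This conversion presumably maps the TOSA convention of -1 for an unknown dimension onto MLIR's dynamic-dimension sentinel. A standalone sketch under that assumption (kDynamicSentinel is a hypothetical stand-in for ShapedType::kDynamic):

#include <cstdint>
#include <vector>

// Translate a serialized TOSA shape into MLIR form: -1 becomes the dynamic
// sentinel, every other extent is passed through unchanged.
static std::vector<int64_t> toMlirShape(const std::vector<int64_t> &shape,
                                        int64_t kDynamicSentinel) {
  std::vector<int64_t> result;
  result.reserve(shape.size());
  for (int64_t dim : shape)
    result.push_back(dim == -1 ? kDynamicSentinel : dim);
  return result;
}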
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition: TosaOps.cpp:2448
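Zero points are constrained by the operand's element type. The following is only a hedged sketch of typical constraints (floating-point zero points must be 0, integer zero points must be representable in the storage width); the real verifier enforces the exact per-operator rules of the TOSA specification:

#include <cstdint>

// Coarse zero-point plausibility check; assumes bitWidth is between 1 and 32
// so the shifts below cannot overflow.
static bool zeroPointPlausible(bool isFloat, unsigned bitWidth, bool isUnsigned,
                               int64_t zp) {
  if (isFloat)
    return zp == 0;
  const int64_t lo = isUnsigned ? 0 : -(int64_t(1) << (bitWidth - 1));
  const int64_t hi = isUnsigned ? (int64_t(1) << bitWidth) - 1
                                : (int64_t(1) << (bitWidth - 1)) - 1;
  return zp >= lo && zp <= hi;
}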
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:3888
static LogicalResult verifyPoolingOp(T op)
Definition: TosaOps.cpp:1026
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition: TosaOps.cpp:1520
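A dimension-is-power-of-two check usually reduces to the bit trick below; this is a standalone illustration, not the dialect's code:

#include <cstdint>

// A positive extent is a power of two exactly when it has a single set bit.
static bool isPowerOfTwo(int64_t dimSize) {
  return dimSize > 0 && (dimSize & (dimSize - 1)) == 0;
}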
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:188
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeComponents.
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:919
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:4278
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4293
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:4266
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
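Broadcasting follows the familiar right-aligned rule. A standalone sketch, with -1 marking a dynamic dimension (illustrative only, not the Traits.cpp implementation):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Combine two shapes dimension by dimension, aligning from the right.
// Missing leading dimensions behave like 1; mismatched static dimensions
// make the shapes non-broadcastable.
static bool broadcastShapes(const std::vector<int64_t> &a,
                            const std::vector<int64_t> &b,
                            std::vector<int64_t> &out) {
  const std::size_t rank = std::max(a.size(), b.size());
  out.assign(rank, 1);
  for (std::size_t i = 0; i < rank; ++i) {
    const int64_t da = i < rank - a.size() ? 1 : a[i - (rank - a.size())];
    const int64_t db = i < rank - b.size() ? 1 : b[i - (rank - b.size())];
    if (da == 1)
      out[i] = db;
    else if (db == 1)
      out[i] = da;
    else if (da == db)
      out[i] = da;
    else if (da == -1 || db == -1)
      out[i] = -1;
    else
      return false; // incompatible static dimensions
  }
  return true;
}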
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:225
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:250
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:4230
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:4254
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:552
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
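Shape compatibility treats a dynamic dimension as matching anything, while static dimensions must agree. A standalone sketch of that rule using -1 for dynamic dimensions (illustrative, not the MLIR implementation):

#include <cstddef>
#include <cstdint>
#include <vector>

// Two ranked shapes are compatible when they have the same rank and every
// pair of static extents agrees; a dynamic extent matches anything.
static bool shapesCompatible(const std::vector<int64_t> &a,
                             const std::vector<int64_t> &b) {
  if (a.size() != b.size())
    return false;
  for (std::size_t i = 0; i < a.size(); ++i)
    if (a[i] != -1 && b[i] != -1 && a[i] != b[i])
      return false;
  return true;
}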
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
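A short usage sketch of the matchPattern/m_Constant pair listed above. The helper function name is hypothetical, but the calls are the standard MLIR matcher API:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Value.h"

// If 'value' is produced by a constant-foldable op, bind its elements
// attribute and return true; otherwise leave 'elements' untouched.
static bool getConstantElements(mlir::Value value,
                                mlir::DenseElementsAttr &elements) {
  return mlir::matchPattern(value, mlir::m_Constant(&elements));
}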
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
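A standalone sketch of the permutation check (each index 0..n-1 must appear exactly once); illustrative only:

#include <cstdint>
#include <vector>

// An interchange vector of length n is a permutation when every entry lies in
// [0, n) and no entry repeats.
static bool isPermutation(const std::vector<int64_t> &interchange) {
  std::vector<bool> seen(interchange.size(), false);
  for (int64_t v : interchange) {
    if (v < 0 || v >= static_cast<int64_t>(interchange.size()) || seen[v])
      return false;
    seen[v] = true;
  }
  return true;
}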
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45