TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 #include <numeric>
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // Tosa dialect interface includes.
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47 
48 namespace {
49 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Dialect Function Inliner Interface.
53 //===----------------------------------------------------------------------===//
54 struct TosaInlinerInterface : public DialectInlinerInterface {
55  using DialectInlinerInterface::DialectInlinerInterface;
56 
57  //===--------------------------------------------------------------------===//
58  // Analysis Hooks.
59  //===--------------------------------------------------------------------===//
60 
61  /// All operations can be inlined by default.
62  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63  IRMapping &map) const final {
64  return true;
65  }
66 
67  /// All regions with If and While parent operators can be inlined.
68  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69  IRMapping &map) const final {
70  return (isa<tosa::IfOp>(dest->getParentOp()) ||
71  isa<tosa::WhileOp>(dest->getParentOp()));
72  }
73 };
74 
75 /// This class implements the bytecode interface for the Tosa dialect.
76 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77  TosaDialectBytecodeInterface(Dialect *dialect)
78  : BytecodeDialectInterface(dialect) {}
79 
80  //===--------------------------------------------------------------------===//
81  // Attributes
82 
83  Attribute readAttribute(DialectBytecodeReader &reader) const override {
84  return ::readAttribute(getContext(), reader);
85  }
86 
87  LogicalResult writeAttribute(Attribute attr,
88  DialectBytecodeWriter &writer) const override {
89  return ::writeAttribute(attr, writer);
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Types
94 
95  Type readType(DialectBytecodeReader &reader) const override {
96  return ::readType(getContext(), reader);
97  }
98 
99  LogicalResult writeType(Type type,
100  DialectBytecodeWriter &writer) const override {
101  return ::writeType(type, writer);
102  }
103 
104  void writeVersion(DialectBytecodeWriter &writer) const final {
105  // TODO: Populate.
106  }
107 
108  std::unique_ptr<DialectVersion>
109  readVersion(DialectBytecodeReader &reader) const final {
110  // TODO: Populate
111  reader.emitError("Dialect does not support versioning");
112  return nullptr;
113  }
114 
115  LogicalResult upgradeFromVersion(Operation *topLevelOp,
116  const DialectVersion &version) const final {
117  return success();
118  }
119 };
120 
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // TOSA control flow support.
125 //===----------------------------------------------------------------------===//
126 
127 /// Returns the while loop body.
128 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129  return {&getBodyGraph()};
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA variable operator support.
134 //===----------------------------------------------------------------------===//
135 
136 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137  return to_vector(llvm::map_range(shape, [](int64_t dim) {
138  return dim == -1 ? ShapedType::kDynamic : dim;
139  }));
140 }
141 
142 // Returns the tensor type of a tosa.variable op.
143 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144  Type elementType = variableOp.getType();
145  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147  return RankedTensorType::get(shape, elementType);
148 }
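// For illustration (not part of the original source): a tosa.variable with
// var_shape = [-1, 4] and element type f32 yields the ranked tensor type
// tensor<?x4xf32>, since -1 in the stored shape maps to ShapedType::kDynamic.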
149 
150 //===----------------------------------------------------------------------===//
151 // Tosa dialect initialization.
152 //===----------------------------------------------------------------------===//
153 
154 void TosaDialect::initialize() {
155  addTypes<
156 #define GET_TYPEDEF_LIST
157 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158  >();
159  addOperations<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162  >();
163  addAttributes<
164 #define GET_ATTRDEF_LIST
165 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166  >();
167  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168  declarePromisedInterfaces<
169  shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175  GreaterEqualOp, MatMulOp>();
176 }
177 
178 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179  Type type, Location loc) {
180  // Tosa dialect constants only support ElementsAttr, unlike the standard
181  // dialect constant, which supports all attributes.
182  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183  return tosa::ConstShapeOp::create(builder, loc, type,
184  llvm::cast<DenseIntElementsAttr>(value));
185  }
186  if (llvm::isa<ElementsAttr>(value))
187  return tosa::ConstOp::create(builder, loc, type,
188  llvm::cast<ElementsAttr>(value));
189  return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Parsers and printers
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199  DenseElementsAttr &varShapeAttr,
200  TypeAttr &typeAttr) {
201  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202  if (!shapedType.hasRank())
203  return parser.emitError(parser.getCurrentLocation())
204  << "expected ranked type";
205 
206  auto elementType = shapedType.getElementType();
207  typeAttr = TypeAttr::get(elementType);
208  ArrayRef<int64_t> shape = shapedType.getShape();
209  Builder builder(parser.getContext());
210  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211  return success();
212  }
213  return parser.emitError(parser.getCurrentLocation())
214  << "expected shaped type";
215 }
216 
217 } // namespace
218 
219 // parses the optional initial value or type for a tosa variable
220 // with initial value:
221 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222 //
223 // without initial value:
224 // tosa.variable @name : tensor<1x8xf32>
225 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227  Attribute &initialValueAttr) {
228  if (succeeded(parser.parseOptionalEqual())) {
229  if (failed(parser.parseAttribute(initialValueAttr))) {
230  return parser.emitError(parser.getCurrentLocation())
231  << "expected attribute";
232  }
233  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235  typeAttr);
236  }
237  return parser.emitError(parser.getCurrentLocation())
238  << "expected Typed attr";
239  }
240 
241  initialValueAttr = nullptr;
242  Type parsedType;
243  if (failed(parser.parseColonType(parsedType))) {
244  return parser.emitError(parser.getCurrentLocation())
245  << "expected type after colon";
246  }
247  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248 }
249 
250 void mlir::tosa::printVariableOpTypeOrInitialValue(
251  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252  TypeAttr typeAttr, Attribute initialValueAttr) {
253  bool needsSpace = false;
254  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255  auto shape =
256  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257  Type elementType = typeAttr.getValue();
258  RankedTensorType tensorType =
259  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260  auto tensorTypeAttr = TypeAttr::get(tensorType);
261  p << ": ";
262  p.printAttribute(tensorTypeAttr);
263  needsSpace = true; // subsequent attr value needs a space separator
264  }
265  if (initialValueAttr) {
266  if (needsSpace)
267  p << ' ';
268  p << "= ";
269  p.printAttribute(initialValueAttr);
270  }
271 }
272 
273 namespace {
274 
275 // parse attributes with special handling for tosa enum attributes
276 template <typename EnumType>
277 ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
278  NamedAttrList &outAttrs) {
279  llvm::StringRef name;
280  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
281  return failure();
282 
283  // special handling: rounding_mode accepts a *bare* RoundingMode enum
284  // keyword.
285  llvm::StringRef kw;
286  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
287  if (name == "rounding_mode" &&
288  succeeded(parser.parseOptionalKeyword(&kw))) {
289  auto sym = symbolizeRoundingMode(kw);
290  if (!sym)
291  return parser.emitError(parser.getCurrentLocation())
292  << "invalid rounding_mode value: " << kw;
293  auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
294  outAttrs.push_back(NamedAttribute(name, attr));
295  return success();
296  }
297  }
298  // special handling: mode accepts a *bare* ResizeMode enum keyword.
299  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
300  if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
301  auto sym = symbolizeResizeMode(kw);
302  if (!sym)
303  return parser.emitError(parser.getCurrentLocation())
304  << "invalid resize mode value: " << kw;
305  auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
306  outAttrs.push_back(NamedAttribute(name, attr));
307  return success();
308  }
309  }
310  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
311  // keyword.
312  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
313  if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
314  auto sym = symbolizeNanPropagationMode(kw);
315  if (!sym)
316  return parser.emitError(parser.getCurrentLocation())
317  << "invalid nan_mode value: " << kw;
318  auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
319  outAttrs.push_back(NamedAttribute(name, attr));
320  return success();
321  }
322  }
323 
324  // special handling: block_size accepts a *bare* BlockSize enum keyword.
325  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
326  if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
327  auto sym = symbolizeBlockSize(kw);
328  if (!sym)
329  return parser.emitError(parser.getCurrentLocation())
330  << "invalid block_size value: " << kw;
331  auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
332  outAttrs.push_back(NamedAttribute(name, attr));
333  return success();
334  }
335  }
336 
337  // Default path: parse any normal attribute literal, including fully qualified
338  // enum keyword
339  Attribute attr;
340  return parser.parseAttribute(attr, name, outAttrs);
341 }
342 
343 template <typename EnumType>
344 ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
345  // parse operands
346  SmallVector<OpAsmParser::UnresolvedOperand> operands;
347  if (parser.parseCommaSeparatedList(
348  [&]() { return parser.parseOperand(operands.emplace_back()); }))
349  return failure();
350 
351  // Parse { attr-dict } with special handling for enum bare token
352  NamedAttrList attrs;
353  if (succeeded(parser.parseOptionalLBrace()) &&
354  failed(parser.parseOptionalRBrace())) {
355  do {
356  if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
357  return failure();
358  } while (succeeded(parser.parseOptionalComma()));
359  if (parser.parseRBrace())
360  return failure();
361  }
362 
363  FunctionType fnTy;
364  if (parser.parseColonType(fnTy))
365  return failure();
366 
367  // Resolve operands and types
368  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
369  parser.getCurrentLocation(),
370  result.operands)))
371  return failure();
372 
373  result.addTypes(fnTy.getResults());
374  result.addAttributes(attrs);
375 
376  return success();
377 }
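// For illustration (a hypothetical attribute dictionary, not from the original
// source): parseWithEnumHandling accepts a bare enum keyword such as
//   {rounding_mode = SINGLE_ROUND}
// in addition to the fully qualified enum attribute form produced by the
// generic assembly printer.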
378 
379 void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
380  parser << namedAttr.getName().strref() << " = ";
381  auto attr = namedAttr.getValue();
382  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
383  parser << roundingModeAttr.getValue();
384  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
385  parser << resizeModeAttr.getValue();
386  } else if (auto nanPropagationModeAttr =
387  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
388  parser << nanPropagationModeAttr.getValue();
389  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
390  parser << blockSizeAttr.getValue();
391  } else {
392  parser.printAttribute(attr);
393  }
394 }
395 
396 // print with special handling for the default-valued NanPropagationMode attribute
397 void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
398  parser << " ";
399  parser.printOperands(op->getOperands());
400 
401  NamedAttrList toPrint(op->getAttrs());
402  // elide the default NanPropagationMode attribute
403  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
404  for (auto attr : op->getAttrs()) {
405  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
406  if (nanAttr.getValue() == kDefaultNanValue) {
407  // elide from toPrint
408  toPrint.erase(attr.getName());
409  break;
410  }
411  }
412  }
413 
414  if (!toPrint.empty()) {
415  parser << " {";
416  llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
417  printNamedAttr(parser, namedAttr);
418  });
419  parser << "}";
420  }
421 
422  parser << " : ";
423  parser.printFunctionalType(op);
424 }
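// For illustration (not part of the original source): because PROPAGATE is the
// default, an op carrying {nan_mode = PROPAGATE} prints without the attribute,
// while {nan_mode = IGNORE} is printed explicitly.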
425 
426 // print with special handling for enums: RoundingMode, ResizeMode, BlockSize
427 void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
428  parser << " ";
429  parser.printOperands(op->getOperands());
430 
431  if (!op->getAttrs().empty()) {
432  parser << " {";
433  llvm::interleaveComma(op->getAttrs(), parser,
434  [&](const NamedAttribute namedAttr) {
435  printNamedAttr(parser, namedAttr);
436  });
437  parser << "}";
438  }
439 
440  parser << " : ";
441  parser.printFunctionalType(op);
442 }
443 
444 } // namespace
445 
446 ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
447  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
448 }
449 
450 void RescaleOp::print(OpAsmPrinter &parser) {
451  printWithEnumHandling(parser, *this);
452 }
453 
454 ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
455  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
456 }
457 
458 void ApplyScaleOp::print(OpAsmPrinter &parser) {
459  printWithEnumHandling(parser, *this);
460 }
461 
462 ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
463  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
464 }
465 
466 void ResizeOp::print(OpAsmPrinter &parser) {
467  printWithEnumHandling(parser, *this);
468 }
469 
470 ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
471  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
472 }
473 
474 void ArgMaxOp::print(OpAsmPrinter &parser) {
475  printWithNanPropagationHandling(parser, *this);
476 }
477 
478 ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
479  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
480 }
481 
482 void MaxPool2dOp::print(OpAsmPrinter &parser) {
483  printWithNanPropagationHandling(parser, *this);
484 }
485 
486 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
487  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
488 }
489 
490 void ClampOp::print(OpAsmPrinter &parser) {
491  printWithNanPropagationHandling(parser, *this);
492 }
493 
494 ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
495  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
496 }
497 
498 void MaximumOp::print(OpAsmPrinter &parser) {
499  printWithNanPropagationHandling(parser, *this);
500 }
501 
502 ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
503  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
504 }
505 
506 void MinimumOp::print(OpAsmPrinter &parser) {
507  printWithNanPropagationHandling(parser, *this);
508 }
509 
510 ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
511  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
512 }
513 
514 void ReduceMaxOp::print(OpAsmPrinter &parser) {
515  printWithNanPropagationHandling(parser, *this);
516 }
517 
518 ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
519  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
520 }
521 
522 void ReduceMinOp::print(OpAsmPrinter &parser) {
523  printWithNanPropagationHandling(parser, *this);
524 }
525 
526 ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
527  OperationState &result) {
528  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
529 }
530 
531 void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532  printWithEnumHandling(parser, *this);
533 }
534 
535 ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
536  OperationState &result) {
537  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
538 }
539 
540 void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
541  printWithEnumHandling(parser, *this);
542 }
543 
544 ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
545  OperationState &result) {
546  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
547 }
548 
549 void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
550  printWithEnumHandling(parser, *this);
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Tosa utilities.
555 //===----------------------------------------------------------------------===//
556 
557 static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
558  if (lhs % rhs != 0)
559  return std::nullopt;
560  return lhs / rhs;
561 }
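// For illustration (not part of the original source): idivCheck(12, 4) returns
// 3, while idivCheck(13, 4) returns std::nullopt because 13 is not evenly
// divisible by 4.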
562 
563 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
564  auto srcType = getElementTypeOrSelf(type);
565  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
566  srcType = quantType.getStorageType();
567  return srcType;
568 }
569 
570 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
571  return getStorageElementTypeOrSelf(value.getType());
572 }
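// For illustration (not part of the original source): for a value of type
// tensor<4x!quant.uniform<i8:f32, 0.5>> this returns the storage type i8,
// while for tensor<4xf32> it simply returns f32.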
573 
574 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
575  Value valZp, StringRef name) {
576  Type eType = getStorageElementTypeOrSelf(val.getType());
577  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
578 
579  bool bothInts =
580  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
581  bool sameBitWidth =
582  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
583 
584  if (!bothInts || !sameBitWidth) {
585  return op->emitOpError()
586  << "expected " << name << " and " << name
587  << "_zp to both be integer of the same bitwidth, but got " << eType
588  << " vs. " << eZpType;
589  }
590  return success();
591 }
592 
593 // Create a pad-const tensor with value `val` of the required data type.
594 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
595  Value src, int32_t val) {
596  const auto srcType = getElementTypeOrSelf(src);
597  const auto srcElemType = getStorageElementTypeOrSelf(src);
598  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
599  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
600  const auto padConstAttr{
601  llvm::isa<FloatType>(srcElemType)
602  ? DenseElementsAttr::get(padConstEType,
603  builder.getFloatAttr(srcElemType, val))
604  : DenseElementsAttr::get(padConstEType,
605  builder.getIntegerAttr(srcElemType, val))};
606  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
607 }
608 
610  if (dyn_cast<tosa::mxint8Type>(type))
611  return 8;
612  return type.getIntOrFloatBitWidth();
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // TOSA Operator Verifiers.
617 //===----------------------------------------------------------------------===//
618 
619 template <typename T>
620 static LogicalResult verifyConvOp(T op) {
621  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
622  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
623 
624  auto inputEType = inputType.getElementType();
625  auto weightEType = weightType.getElementType();
626  auto biasEType =
627  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
628  auto resultEType =
629  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
630  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
631  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
632 
633  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
634  inputEType = quantType.getStorageType();
635 
636  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
637  weightEType = quantType.getStorageType();
638 
639  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
640  biasEType = quantType.getStorageType();
641 
642  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
643  resultEType = quantType.getStorageType();
644 
645  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
646  // for now, only enforce bias element type == result element type for
647  // float types.
648  op.emitOpError(
649  "expect both bias and result to have same element type, got ")
650  << biasEType << " and " << resultEType;
651  return failure();
652  }
653 
654  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
655  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
656  if (inputEType != weightEType) {
657  op.emitOpError(
658  "expect both input and weight to have same element type, got ")
659  << inputEType << " and " << weightEType;
660  return failure();
661  }
662  }
663 
664  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
665  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
666 
667  // Either both must be float or both non-float.
668  if (inputIsFloat != weightIsFloat) {
669  op.emitOpError(
670  "expect both input and weight to be float or not together, got ")
671  << inputEType << " and " << weightEType;
672  return failure();
673  }
674 
675  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
676  if (inputEType != inputZpEType) {
677  return op.emitOpError("expect both input and its zero point are the same "
678  "element type, got ")
679  << inputEType << " and " << inputZpEType;
680  }
681 
682  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
683  if (weightEType != weightZpEType) {
684  return op.emitOpError("expect both weight and its zero point are the same "
685  "element type, got ")
686  << weightEType << " and " << weightZpEType;
687  }
688 
689  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
690  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
691  return failure();
692 
693  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
694  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
695  return failure();
696 
697  return success();
698 }
699 
700 LogicalResult tosa::ConstOp::verify() {
701 
702  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
703  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
704 
705  if (!attrType || !outputType) {
706  emitOpError("expected tensors for attr/result type");
707  return failure();
708  }
709 
710  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
711  outputType.getElementType())) {
712  if (result.getStorageType() == attrType.getElementType())
713  return success();
714  }
715 
716  if (attrType.getElementType() != outputType.getElementType()) {
717  emitOpError("expected same attr/result element types");
718  return failure();
719  }
720 
721  return success();
722 }
723 
724 template <typename T>
725 static LogicalResult verifyConvOpModes(T op) {
726  auto inputEType =
727  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
728 
729  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
730  inputEType = quantType.getStorageType();
731 
732  auto accType = op.getAccType();
733  if (inputEType.isInteger(8) && !accType.isInteger(32))
734  return op.emitOpError("accumulator type for i8 tensor is not i32");
735 
736  if (inputEType.isInteger(16) && !accType.isInteger(48))
737  return op.emitOpError("accumulator type for i16 tensor is not i48");
738 
739  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
740  return op.emitOpError("accumulator type for f8 tensor is not f16");
741 
742  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
743  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
744 
745  if (inputEType.isBF16() && !accType.isF32())
746  return op.emitOpError("accumulator type for bf16 tensor is not f32");
747 
748  if (inputEType.isF32() && !accType.isF32())
749  return op.emitOpError("accumulator type for f32 tensor is not f32");
750 
751  auto resultEType =
752  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
753 
754  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
755  resultEType = quantType.getStorageType();
756 
757  return success();
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // ERROR_IF functions.
762 // ERROR_IF is a predicate that must set an error if the condition holds.
763 //===----------------------------------------------------------------------===//
764 
765 template <typename T>
766 static LogicalResult verifyConvOpErrorIf(T op) {
767  llvm::ArrayRef<int64_t> padding = op.getPad();
768  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
769  return op.emitOpError("expect all padding values to be >= 0, got ")
770  << padding;
771 
772  llvm::ArrayRef<int64_t> strides = op.getStride();
773  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
774  return op.emitOpError("expect all stride values to be >= 1, got ")
775  << strides;
776 
777  llvm::ArrayRef<int64_t> dilations = op.getDilation();
778  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
779  return op.emitOpError("expect all dilation values to be >= 1, got ")
780  << dilations;
781 
782  const RankedTensorType outputType =
783  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
784  if (!outputType)
785  // Skip following checks if output is not ranked
786  return success();
787 
788  const RankedTensorType inputType =
789  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
790  const RankedTensorType weightType =
791  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
792 
793  if (inputType && weightType) {
794  const auto verifyOutputSize =
795  [&op](const int64_t inputSize, const int64_t kernelSize,
796  const int64_t outputSize, const int64_t padBefore,
797  const int64_t padAfter, const int64_t stride,
798  const int64_t dilation, const llvm::StringRef dimName,
799  const llvm::StringRef dimAxis,
800  const llvm::StringRef padBeforeName,
801  const llvm::StringRef padAfterName) -> LogicalResult {
802  if (inputSize == ShapedType::kDynamic ||
803  kernelSize == ShapedType::kDynamic)
804  return success();
805 
806  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
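 // For illustration (values chosen arbitrarily): with I = 17, K = 3,
 // pa = pb = 1, d = 1, and s = 2, the numerator is 17 - 1 + 1 + 1 - 2 = 16,
 // which divides evenly by 2, so O must equal 16 / 2 + 1 = 9.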
807 
808  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
809  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
810  stride);
811  if (!calculatedOutSizeMinusOne.has_value())
812  return op.emitOpError("expected input_")
813  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
814  << padAfterName << " - (kernel_" << dimName
815  << " - 1) * dilation_" << dimAxis
816  << " to be wholly divisible by stride_" << dimAxis << ", got ("
817  << inputSize << " - 1 + " << padBefore << " + " << padAfter
818  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
819  << stride;
820 
821  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
822  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
823  return op.emitOpError("calculated output ")
824  << dimName << " did not match expected: "
825  << "calculated=" << calculatedOutSize
826  << ", expected=" << outputSize;
827 
828  return success();
829  };
830 
831  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
832  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
833  if (failed(verifyOutputSize(
834  inputType.getDimSize(1), weightType.getDimSize(1),
835  outputType.getDimSize(1), padding[0], padding[1], strides[0],
836  dilations[0], "height", "y", "top", "bottom")))
837  return failure();
838 
839  if (failed(verifyOutputSize(
840  inputType.getDimSize(2), weightType.getDimSize(2),
841  outputType.getDimSize(2), padding[2], padding[3], strides[1],
842  dilations[1], "width", "x", "left", "right")))
843  return failure();
844  }
845 
846  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
847  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
848  if (failed(verifyOutputSize(
849  inputType.getDimSize(1), weightType.getDimSize(0),
850  outputType.getDimSize(1), padding[0], padding[1], strides[0],
851  dilations[0], "height", "y", "top", "bottom")))
852  return failure();
853 
854  if (failed(verifyOutputSize(
855  inputType.getDimSize(2), weightType.getDimSize(1),
856  outputType.getDimSize(2), padding[2], padding[3], strides[1],
857  dilations[1], "width", "x", "left", "right")))
858  return failure();
859  }
860 
861  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
862  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
863  if (failed(verifyOutputSize(
864  inputType.getDimSize(1), weightType.getDimSize(1),
865  outputType.getDimSize(1), padding[0], padding[1], strides[0],
866  dilations[0], "depth", "d", "front", "back")))
867  return failure();
868 
869  if (failed(verifyOutputSize(
870  inputType.getDimSize(2), weightType.getDimSize(2),
871  outputType.getDimSize(2), padding[2], padding[3], strides[1],
872  dilations[1], "height", "y", "top", "bottom")))
873  return failure();
874 
875  if (failed(verifyOutputSize(
876  inputType.getDimSize(3), weightType.getDimSize(3),
877  outputType.getDimSize(3), padding[4], padding[5], strides[2],
878  dilations[2], "width", "x", "left", "right")))
879  return failure();
880  }
881  }
882 
883  const RankedTensorType biasType =
884  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
885  if (!biasType)
886  // Skip following checks if bias is not ranked
887  return success();
888 
889  const int64_t biasChannels = biasType.getDimSize(0);
890  const int64_t outputChannels =
891  outputType.getDimSize(outputType.getRank() - 1);
892  if (biasChannels == ShapedType::kDynamic ||
893  outputChannels == ShapedType::kDynamic)
894  // Skip following checks if biasChannels or outputChannels is dynamic dim
895  return success();
896 
897  if (biasChannels != outputChannels && biasChannels != 1)
898  return op.emitOpError(
899  "bias channels expected to be equal to output channels (")
900  << outputChannels << ") or 1, got " << biasChannels;
901 
902  return success();
903 }
904 
905 // Verify that the two given types have the same element type and shape.
906 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
907  StringRef name1, Type type2,
908  StringRef name2) {
909  auto shapeType1 = dyn_cast<ShapedType>(type1);
910  auto shapeType2 = dyn_cast<ShapedType>(type2);
911  if (!shapeType1 || !shapeType2)
912  return failure();
913 
914  auto elemType1 = shapeType1.getElementType();
915  auto elemType2 = shapeType2.getElementType();
916  if (elemType1 != elemType2)
917  return op->emitOpError()
918  << "require same element type for " << name1 << " (" << elemType1
919  << ") and " << name2 << " (" << elemType2 << ")";
920 
921  if (failed(verifyCompatibleShape(type1, type2)))
922  return op->emitOpError()
923  << "require same shapes for " << name1 << " (" << type1 << ") and "
924  << name2 << " (" << type2 << ")";
925 
926  return success();
927 }
928 
929 // Verify that the two given tensor lists have the same length, type, and shape.
930 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
931  StringRef name1,
932  ValueRange list2,
933  StringRef name2) {
934  if (list1.size() != list2.size())
935  return op->emitOpError()
936  << "require same number of values in " << name1 << " ("
937  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
938 
939  for (auto [type1, type2] :
940  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
941  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
942  return failure();
943  }
944 
945  return success();
946 }
947 
948 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
949  ShapeAdaptor shapeAdaptor(type);
950  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
951  return success();
952 
953  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
954 }
955 
956 template <typename T>
957 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
958  Operation *symTableOp =
959  op->template getParentWithTrait<OpTrait::SymbolTable>();
960  if (!symTableOp)
961  // If the operation is not enclosed within a symbol table scope, we cannot
962  // verify it against its declaration.
963  return success();
964 
965  SymbolTable symTable(symTableOp);
966  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
967 
968  // Verify prior declaration
969  if (!varOp)
970  return op->emitOpError("'")
971  << op.getName() << "' has not been declared by 'tosa.variable'";
972 
973  // Verify type and shape
974  auto variableType = getVariableType(varOp);
975  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
976  "the input tensor")
977  .failed())
978  return failure();
979  return success();
980 }
981 
982 // Verify that aType and bType have the same element type.
983 template <typename T>
984 static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
985  StringRef aName = "input",
986  StringRef bName = "output") {
987  auto aTType = llvm::dyn_cast<TensorType>(aType);
988  auto bTType = llvm::dyn_cast<TensorType>(bType);
989  if (!aTType) {
990  op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
991  return failure();
992  }
993  if (!bTType) {
994  op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
995  return failure();
996  }
997  auto aElementType = aTType.getElementType();
998  auto bElementType = bTType.getElementType();
999  auto aQuantType =
1000  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1001  auto bQuantType =
1002  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1003  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1004  (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1005  aElementType != bElementType) {
1006  // only check if both element types are int/index/float/UniformQuantized
1007  // e.g., not sure how to check quant::QuantizedType
1008  // this happens in test_conv2d_q_grouped_convolution in
1009  // tfl-to-tosa-pipeline.mlir
1010  op.emitOpError("expect ")
1011  << aName << " and " << bName << " to have same element type, got "
1012  << aElementType << " and " << bElementType;
1013  return failure();
1014  }
1015  return success();
1016 }
1017 
1018 LogicalResult tosa::ArgMaxOp::verify() {
1019  const ShapedType resultType = llvm::cast<ShapedType>(getType());
1020 
1021  // Ensure output is an integer (or index) type
1022  if (const auto resultETy = resultType.getElementType();
1023  !resultETy.isIntOrIndex())
1024  return emitOpError("result tensor is not of integer type");
1025 
1026  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
1027  if (!inputType.hasRank())
1028  return success();
1029 
1030  // Ensure axis is within the tensor rank
1031  const int64_t axis = getAxisAttr().getInt();
1032  if (((axis < 0) || axis >= inputType.getRank()))
1033  return emitOpError("specified axis is outside the rank of the tensor");
1034 
1035  if (!resultType.hasRank())
1036  return success();
1037 
1038  const ArrayRef<int64_t> inputShape = inputType.getShape();
1039  const ArrayRef<int64_t> outputShape = resultType.getShape();
1040  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
1041  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
1042  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
1043  return emitOpError("expected output shape '")
1044  << expectedOutputShape << "', got '" << outputShape << "'";
1045 
1046  return success();
1047 }
1048 
1049 template <typename T>
1050 static LogicalResult verifyPoolingOp(T op) {
1051  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
1052  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
1053  return op.emitOpError("expect all kernel values to be >= 1, got ")
1054  << kernel;
1055 
1056  const llvm::ArrayRef<int64_t> strides = op.getStride();
1057  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
1058  return op.emitOpError("expect all stride values to be >= 1, got ")
1059  << strides;
1060 
1061  const llvm::ArrayRef<int64_t> padding = op.getPad();
1062  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
1063  return op.emitOpError("expect all padding values to be >= 0, got ")
1064  << padding;
1065 
1066  // Padding must be less than kernel size to avoid a divide-by-zero
1067  const int64_t kernelX = kernel[1];
1068  const int64_t padLeft = padding[2];
1069  const int64_t padRight = padding[3];
1070  if (padRight >= kernelX || padLeft >= kernelX)
1071  return op.emitOpError("expected left/right padding to be less than the "
1072  "width of the kernel, got pad_left=")
1073  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
1074 
1075  const int64_t kernelY = kernel[0];
1076  const int64_t padTop = padding[0];
1077  const int64_t padBottom = padding[1];
1078  if (padTop >= kernelY || padBottom >= kernelY)
1079  return op.emitOpError("expected top/bottom padding to be less than the "
1080  "height of the kernel, got pad_top=")
1081  << padTop << ", pad_bottom=" << padBottom
1082  << ", kernel_y=" << kernelY;
1083 
1084  const auto inputType =
1085  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
1086  const auto outputType =
1087  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
1088  if (!inputType || !outputType)
1089  return success();
1090 
1091  const auto verifyOutputSize =
1092  [&op](const int64_t inputSize, const int64_t outputSize,
1093  const int64_t kernelSize, const int64_t strideSize,
1094  const int64_t padBefore, const int64_t padAfter,
1095  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
1096  const llvm::StringRef padBeforeName,
1097  const llvm::StringRef padAfterName) -> LogicalResult {
1098  if (ShapedType::isDynamic(inputSize))
1099  return success();
1100 
1101  const std::optional<int64_t> calculatedOutSizeMinusOne =
1102  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
1103  if (!calculatedOutSizeMinusOne.has_value())
1104  return op.emitOpError("expected input_")
1105  << dimName << " + pad_" << padBeforeName << " + pad_"
1106  << padAfterName << " - kernel_" << dimAxis
1107  << " to be wholly divisible by stride_" << dimAxis << ", got ("
1108  << inputSize << " + " << padBefore << " + " << padAfter << " - "
1109  << kernelSize << ") / " << strideSize;
1110 
1111  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
1112  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
1113  return op.emitOpError("calculated output ")
1114  << dimName << " did not match expected: "
1115  << "calculated=" << calculatedOutSize
1116  << ", expected=" << outputSize;
1117 
1118  return success();
1119  };
1120 
1121  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
1122  kernel[0], strides[0], padding[0], padding[1],
1123  "height", "y", "top", "bottom")))
1124  return failure();
1125 
1126  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
1127  kernel[1], strides[1], padding[2], padding[3],
1128  "width", "x", "left", "right")))
1129  return failure();
1130 
1131  return success();
1132 }
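// For illustration (values chosen arbitrarily): for a pooled dimension with
// input size 112, kernel 3, stride 2, pad_before 0, and pad_after 1, the
// verifier above expects an output size of (112 + 0 + 1 - 3) / 2 + 1 = 56.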
1133 
1134 LogicalResult tosa::AvgPool2dOp::verify() {
1135  if (failed(verifyPoolingOp(*this)))
1136  return failure();
1137 
1138  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
1139  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
1140  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
1141  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
1142 
1143  auto accType = getAccType();
1144  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
1145  return emitOpError("accumulator type for integer tensor is not i32");
1146 
1147  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
1148  return emitOpError("accumulator type for f16 tensor is not f16/f32");
1149 
1150  if (inputETy.isBF16() && !accType.isF32())
1151  return emitOpError("accumulator type for bf16 tensor is not f32");
1152 
1153  if (inputETy.isF32() && !accType.isF32())
1154  return emitOpError("accumulator type for f32 tensor is not f32");
1155 
1156  if (inputETy != inputZpETy)
1157  return emitOpError("expect both input and its zero point are the same "
1158  "element type, got ")
1159  << inputETy << " and " << inputZpETy;
1160 
1161  if (resultETy != outputZpETy)
1162  return emitOpError("expect both output and its zero point are the same "
1163  "element type, got ")
1164  << resultETy << " and " << outputZpETy;
1165 
1166  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
1167  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
1168  return failure();
1169 
1170  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
1171  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
1172  return failure();
1173 
1174  return success();
1175 }
1176 
1177 LogicalResult tosa::ClampOp::verify() {
1178  mlir::Type inputETy =
1179  llvm::cast<ShapedType>(getInput().getType()).getElementType();
1180  if (auto quantType =
1181  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
1182  inputETy = quantType.getStorageType();
1183  }
1184  mlir::Type outputETy =
1185  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
1186  if (auto quantType =
1187  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
1188  outputETy = quantType.getStorageType();
1189  }
1190  if (inputETy != outputETy)
1191  return emitOpError("input/output element types are incompatible.");
1192 
1193  auto maxValAttr = getMaxValAttr();
1194  auto minValAttr = getMinValAttr();
1195 
1196  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
1197 
1198  if (inputETy.isInteger(dataTypeBitWidth)) {
1199  // if input datatype is integer, check that the min_val/max_val attributes
1200  // are integer attributes, and that their type is the same as the input's
1201  // datatype
1202  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
1203  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
1204  if (!intMaxValAttr || !intMinValAttr ||
1205  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
1206  (intMaxValAttr.getType() != inputETy))
1207  return emitOpError("min/max attributes types are incompatible with "
1208  "input/output element types.");
1209 
1210  const bool isUnsigned = inputETy.isUnsignedInteger();
1211  const bool isBoolean = inputETy.isInteger(1);
1212  const APInt minVal = intMinValAttr.getValue();
1213  const APInt maxVal = intMaxValAttr.getValue();
1214  if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
1215  return emitOpError("expected min_val <= max_val, got min_val=")
1216  << minValAttr << ", max_val=" << maxValAttr;
1217  } else {
1218  // otherwise, input datatype is float, check that the min_val/max_val
1219  // attributes share the same type and that their type is the same as the
1220  // input's datatype
1221  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
1222  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
1223  if (!floatMaxValAttr || !floatMinValAttr ||
1224  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
1225  (floatMaxValAttr.getType() != inputETy))
1226  return emitOpError("min/max attributes types are incompatible with "
1227  "input/output element types.");
1228 
1229  const APFloat minVal = floatMinValAttr.getValue();
1230  const APFloat maxVal = floatMaxValAttr.getValue();
1231  if (minVal.isNaN() || maxVal.isNaN())
1232  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
1233  << minValAttr << ", max_val=" << maxValAttr;
1234 
1235  if (maxVal < minVal)
1236  return emitOpError("expected min_val <= max_val, got min_val=")
1237  << minValAttr << ", max_val=" << maxValAttr;
1238  }
1239 
1240  return success();
1241 }
1242 
1243 //===----------------------------------------------------------------------===//
1244 // TOSA Operator Quantization Builders.
1245 //===----------------------------------------------------------------------===//
1246 
1247 /// This builder is called on all convolution operators except TransposeConv,
1248 /// which has specialized output shape semantics. The builder also defines the
1249 /// bitwidth of the output given the bit width of the input & weight content.
1250 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1251  Type outputType, Value input, Value weight,
1252  Value bias, DenseI64ArrayAttr pad,
1253  DenseI64ArrayAttr stride,
1254  DenseI64ArrayAttr dilation,
1255  TypeAttr accType) {
1256  auto zps = createZPsAsConst(builder, input, weight);
1257  result.addOperands({input, weight, bias, zps.first, zps.second});
1258  result.addAttribute("pad", pad);
1259  result.addAttribute("stride", stride);
1260  result.addAttribute("dilation", dilation);
1261  result.addAttribute("acc_type", accType);
1262  Type finalOutputType = outputType;
1263  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1264  if (quantAttr) {
1265  finalOutputType =
1266  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1267  }
1268  result.addTypes(finalOutputType);
1269 }
1270 
1271 /// Handles tosa.transpose_conv2d which has outpad and output shape
1272 /// attributes.
1273 static void
1274 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1275  Type outputType, Value input, Value weight,
1276  Value bias, DenseI64ArrayAttr outpad,
1277  DenseI64ArrayAttr stride, TypeAttr accType) {
1278  auto zps = createZPsAsConst(builder, input, weight);
1279  result.addOperands({input, weight, bias, zps.first, zps.second});
1280  result.addAttribute("out_pad", outpad);
1281  result.addAttribute("stride", stride);
1282  result.addAttribute("acc_type", accType);
1283  Type finalOutputType = outputType;
1284  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1285  if (quantAttr) {
1286  finalOutputType =
1287  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1288  }
1289  result.addTypes(finalOutputType);
1290 }
1291 
1292 /// The tosa.matmul op is also intended to be generated where a fully_connected
1293 /// op must be constructed but the weight is not a constant. In this case,
1294 /// the fully_connected op must be expressed using matmul.
1295 /// TODO: Add a link to the legalization document explaining this.
1296 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1297  OperationState &result, Type outputType,
1298  Value a, Value b) {
1299  auto zps = createZPsAsConst(builder, a, b);
1300  result.addOperands({a, b, zps.first, zps.second});
1301 
1302  Type finalOutputType{outputType};
1303  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1304  auto eType = getStorageElementTypeOrSelf(a.getType());
1305  auto inputBits = eType.getIntOrFloatBitWidth();
1306 
1307  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1308  assert(outputShapedType && "Output must be a shaped type");
1309 
1310  IntegerType accElementType;
1311  if (inputBits == 16)
1312  accElementType = builder.getIntegerType(48);
1313  else
1314  accElementType = builder.getI32Type();
1315 
1316  finalOutputType = outputShapedType.clone(accElementType);
1317  }
1318  result.addTypes(finalOutputType);
1319 }
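// For illustration (not part of the original source): when the operands of a
// quantized matmul carry i8 storage types, the builder above rewrites the
// result element type to an i32 accumulator; for i16 storage it uses i48.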
1320 
1321 /// Both the tosa.avg_pool2d and unary ops use the same
1322 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
1323 /// it has additional parameters not part of the unary ops.
1324 static void
1325 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1326  Type outputType, Value input,
1327  DenseArrayAttr kernel, DenseArrayAttr stride,
1328  DenseArrayAttr pad, TypeAttr accType) {
1329  const Location loc{result.location};
1330  int64_t inputZp{0};
1331  int64_t outputZp{0};
1332 
1333  if (auto quantAttr =
1334  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1335  inputZp = quantAttr.getInputZp();
1336  outputZp = quantAttr.getOutputZp();
1337  }
1338  const std::optional<Value> inputZpOp =
1339  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1340  if (!inputZpOp) {
1341  (void)emitError(
1342  loc,
1343  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1344  }
1345  const std::optional<Value> outputZpOp =
1346  createZeroPointTensor(builder, loc, outputType, outputZp);
1347  if (!outputZpOp) {
1348  (void)emitError(loc, "Failed to create output zero point tensor for "
1349  "quantized AVG_POOL2D op");
1350  }
1351 
1352  if (inputZpOp && outputZpOp) {
1353  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1354  } else {
1355  // failed to create one or more zero points above: just add input as
1356  // operand; this will trigger an error in building the op because of the
1357  // missing zero points
1358  result.addOperands({input});
1359  }
1360  result.addAttribute("kernel", kernel);
1361  result.addAttribute("stride", stride);
1362  result.addAttribute("pad", pad);
1363  result.addAttribute("acc_type", accType);
1364  result.types.push_back(outputType);
1365 }
1366 
1367 /// This builder is called on single-parameter negate operator
1368 /// to construct input and output zero points based on their
1369 /// types.
1370 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1371  OperationState &result, Type outputType,
1372  Value input) {
1373  const Location loc{result.location};
1374  int64_t input1Zp{0};
1375  int64_t outputZp{0};
1376  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1377  if (quantAttr) {
1378  input1Zp = quantAttr.getInputZp();
1379  outputZp = quantAttr.getOutputZp();
1380  }
1381  const std::optional<Value> input1ZpOp =
1382  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1383  if (!input1ZpOp) {
1384  (void)emitError(
1385  loc, "Failed to create input1 zero point for quantized NEGATE op");
1386  }
1387 
1388  const std::optional<Value> outputZpOp =
1389  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1390  if (!outputZpOp) {
1391  (void)emitError(
1392  loc, "Failed to create output zero point for quantized NEGATE op");
1393  }
1394 
1395  if (input1ZpOp && outputZpOp) {
1396  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1397  } else {
1398  // failed to create one or more zero points above: just add input as
1399  // operands. This will trigger error in building the op because of
1400  // missing zero points
1401  result.addOperands({input});
1402  }
1403 
1404  result.types.push_back(outputType);
1405 }
1406 
1407 /// This builder is called on the TOSA pad operator that needs to create its
1408 /// own OptionalAttr quantization_attr parameter to scale the padding values
1409 /// correctly. A missing pad_const is interpreted as zero padding.
1410 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1411  Type outputType, Value input,
1412  Value paddings) {
1413  const Location loc{result.location};
1414  int32_t zp{0};
1415  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1416  if (quantAttr) {
1417  zp = static_cast<int32_t>(quantAttr.getInputZp());
1418  }
1419  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1420  result.addOperands({input, paddings, padConstOp});
1421  result.types.push_back(outputType);
1422 }
1423 
1424 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1425  StringRef name, Type variableType,
1426  Attribute initialValue) {
1427  const Location loc{result.location};
1428  auto nameAttr = builder.getStringAttr(name);
1429 
1430  auto shapedType = dyn_cast<ShapedType>(variableType);
1431  if (!shapedType) {
1432  (void)emitError(loc, "variable type must be a shaped type");
1433  return;
1434  }
1435  if (!shapedType.hasRank()) {
1436  (void)emitError(loc, "variable type must be a ranked type");
1437  return;
1438  }
1439 
1440  auto elementType = shapedType.getElementType();
1441  auto elementTypeAttr = TypeAttr::get(elementType);
1442  ArrayRef<int64_t> shape = shapedType.getShape();
1443  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1444 
1445  result.addAttribute("sym_name", nameAttr);
1446  result.addAttribute("var_shape", varShapeAttr);
1447  result.addAttribute("type", elementTypeAttr);
1448  result.addAttribute("initial_value", initialValue);
1449 }
1450 
1451 //===----------------------------------------------------------------------===//
1452 // TOSA Operator Return Type Inference.
1453 //===----------------------------------------------------------------------===//
1454 
1455 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1456  SmallVector<int64_t> &outShape) {
1457  int64_t outRank = 0;
1458  for (int i = 0, e = operands.size(); i != e; ++i) {
1459  auto shape = operands.getShape(i);
1460  if (!shape.hasRank()) {
1461  // TODO(jennik): Update function to have better case handling for
1462  // invalid operands and for ranked tensors.
1463  return failure();
1464  }
1465  outRank = std::max<int64_t>(outRank, shape.getRank());
1466  }
1467 
1468  outShape.resize(outRank, 1);
1469 
1470  for (int i = 0, e = operands.size(); i != e; ++i) {
1471  auto shape = operands.getShape(i);
1472  auto rankDiff = outShape.size() - shape.getRank();
1473 
1474  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1475  auto dim1 = outShape[i + rankDiff];
1476  auto dim2 = shape.getDimSize(i);
1477  auto resolvedDim = dim1;
1478 
1479  if (dim1 == 1) {
1480  resolvedDim = dim2;
1481  } else if (dim2 == 1) {
1482  resolvedDim = dim1;
1483  } else if (dim1 != dim2) {
1484  return failure();
1485  }
1486  outShape[i + rankDiff] = resolvedDim;
1487  }
1488  }
1489 
1490  return success();
1491 }
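// For illustration (not part of the original source): broadcasting operand
// shapes [2, 1, 3] and [4, 1] resolves to the output shape [2, 4, 3]; shapes
// [2, 3] and [4, 3] fail because 2 and 4 are neither equal nor 1.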
1492 
1493 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1494  MLIRContext *context, ::std::optional<Location> location,
1495  ArgMaxOp::Adaptor adaptor,
1496  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1497  ShapeAdaptor inputShape(adaptor.getInput().getType());
1498  IntegerAttr axis = adaptor.getProperties().axis;
1499  int32_t axisVal = axis.getValue().getSExtValue();
1500 
1501  if (!inputShape.hasRank()) {
1502  inferredReturnShapes.push_back(ShapedTypeComponents());
1503  return success();
1504  }
1505 
1506  SmallVector<int64_t> outShape;
1507  outShape.reserve(inputShape.getRank() - 1);
1508  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1509  if (i == axisVal)
1510  continue;
1511  outShape.push_back(inputShape.getDimSize(i));
1512  }
1513 
1514  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1515  return success();
1516 }
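// Example of the inference above (shapes assumed purely for illustration):
//   input tensor<2x3x4xf32> with axis = 1  ->  inferred output shape [2, 4];
//   the axis dimension is dropped and all other dimensions are kept in order.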
1517 
1518 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1519  MLIRContext *context, ::std::optional<Location> location,
1520  RFFT2dOp::Adaptor adaptor,
1521  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1522  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1523 
1524  if (!inputShape.hasRank())
1525  return failure();
1526 
1527  llvm::SmallVector<int64_t> outputShape;
1528  outputShape.resize(3, ShapedType::kDynamic);
1529  outputShape[0] = inputShape.getDimSize(0);
1530  outputShape[1] = inputShape.getDimSize(1);
1531  int64_t inWidth = inputShape.getDimSize(2);
1532 
1533  // Note that we can support this calculation symbolically
1534  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1535  if (inWidth != ShapedType::kDynamic)
1536  outputShape[2] = inWidth / 2 + 1;
1537 
1538  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1539  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1540 
1541  return success();
1542 }
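// Worked example for the inference above (illustrative values only):
//   input_real shape [8, 16, 32]  ->  both outputs get shape [8, 16, 32/2 + 1]
//   = [8, 16, 17]; a dynamic input width leaves the last output dim dynamic.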
1543 
1544 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1545  const llvm::StringRef dimName) {
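  // A positive integer is a power of two iff it has a single set bit, i.e.
  // (dimSize & (dimSize - 1)) == 0; the dimSize > 0 term excludes zero.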
1546  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1547  if (!isPowerOfTwo)
1548  return op->emitOpError("expected ")
1549  << dimName << " to be a power of two, got " << dimSize;
1550 
1551  return success();
1552 }
1553 
1554 LogicalResult tosa::RFFT2dOp::verify() {
1555  const auto outputTypes = getResultTypes();
1556  if (failed(verifyCompatibleShapes(outputTypes)))
1557  return emitOpError("expected output shapes to match, got ") << outputTypes;
1558 
1559  const auto inputType =
1560  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1561  if (!inputType)
1562  return success();
1563 
1564  const int64_t height = inputType.getDimSize(1);
1565  if (ShapedType::isStatic(height) &&
1566  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1567  return failure();
1568 
1569  const int64_t width = inputType.getDimSize(2);
1570  if (ShapedType::isStatic(width) &&
1571  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1572  return failure();
1573 
1574  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1575  if (!outputType)
1576  return success();
1577 
1578  // Batch and height input/output dimensions should match
1579  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1580  outputType.getShape().drop_back())))
1581  return emitOpError("expected batch and height dimensions of input/output "
1582  "to match, got input=")
1583  << inputType << " output=" << outputType;
1584 
1585  // Output width dimension expected to be input_width / 2 + 1
1586  const int64_t outputWidth = outputType.getDimSize(2);
1587  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1588  (outputWidth != (width / 2) + 1))
1589  return emitOpError(
1590  "expected output width to be equal to input_width / 2 + 1, got ")
1591  << outputWidth;
1592 
1593  return success();
1594 }
1595 
1596 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1597  MLIRContext *context, ::std::optional<Location> location,
1598  FFT2dOp::Adaptor adaptor,
1599  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1600  inferredReturnShapes.push_back(
1601  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1602  inferredReturnShapes.push_back(
1603  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1604  return success();
1605 }
1606 
1607 LogicalResult tosa::FFT2dOp::verify() {
1608  const auto inputRealType =
1609  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1610  const auto inputImagType =
1611  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1612  if (!inputRealType || !inputImagType)
1613  return success();
1614 
1615  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1616  return ShapedType::isDynamic(a) ? b : a;
1617  };
1618 
1619  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1620  inputImagType.getDimSize(1));
1621  if (ShapedType::isStatic(height) &&
1622  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1623  return failure();
1624 
1625  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1626  inputImagType.getDimSize(2));
1627  if (ShapedType::isStatic(width) &&
1628  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1629  return failure();
1630 
1631  return success();
1632 }
1633 
1634 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1635  MLIRContext *context, ::std::optional<Location> location,
1636  ConcatOp::Adaptor adaptor,
1637  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1638  // Infer all dimension sizes by reducing based on inputs.
1639  const Properties &prop = adaptor.getProperties();
1640  int32_t axis = prop.axis.getValue().getSExtValue();
1641  llvm::SmallVector<int64_t> outputShape;
1642  bool hasRankedInput = false;
1643  for (auto operand : adaptor.getOperands()) {
1644  ShapeAdaptor operandShape(operand.getType());
1645  if (!operandShape.hasRank())
1646  continue;
1647 
1648  // Copy the Operand's rank.
1649  if (!hasRankedInput)
1650  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1651 
1652  // Copy each static (non-dynamic) dimension size, skipping the concat axis.
1653  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1654  if (i == axis || operandShape.isDynamicDim(i))
1655  continue;
1656  if (outputShape[i] == ShapedType::kDynamic)
1657  outputShape[i] = operandShape.getDimSize(i);
1658  if (outputShape[i] != operandShape.getDimSize(i))
1659  return emitOptionalError(location,
1660  "Cannot concat tensors with different sizes"
1661  " on the non-axis dimension ",
1662  i);
1663  }
1664 
1665  hasRankedInput = true;
1666  }
1667 
1668  if (adaptor.getInput1().empty())
1669  return failure();
1670 
1671  Type inputType =
1672  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1673  if (!hasRankedInput) {
1674  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1675  return success();
1676  }
1677 
1678  // Determine the dimension size along the concatenation axis.
1679  int64_t concatDimSize = 0;
1680  for (auto operand : adaptor.getOperands()) {
1681  ShapeAdaptor operandShape(operand.getType());
1682 
1683  // We need to know the length of the concatenation axis of all inputs to
1684  // determine the dimension size of the output shape.
1685  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1686  concatDimSize = ShapedType::kDynamic;
1687  break;
1688  }
1689 
1690  concatDimSize += operandShape.getDimSize(axis);
1691  }
1692 
1693  outputShape[axis] = concatDimSize;
1694 
1695  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1696  return success();
1697 }
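// Example of the inference above (shapes assumed purely for illustration):
//   inputs [2, 3] and [2, 5] concatenated on axis = 1  ->  output [2, 8];
//   if any input's axis dimension is dynamic, the output axis dim is dynamic.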
1698 
1699 LogicalResult tosa::ConcatOp::verify() {
1700  // check that each input has same element type as output
1701  auto outType = getOutput().getType();
1702  const Operation::operand_range inputList = getInput1();
1703 
1704  // Check there is at least one input
1705  if (inputList.empty())
1706  return emitOpError("expect at least one input");
1707 
1708  if (!llvm::all_of(inputList, [&](auto input) {
1709  return succeeded(verifySameElementTypes(
1710  *this, /* inType = */ input.getType(), outType));
1711  })) {
1712  return failure();
1713  }
1714 
1715  const int32_t axis = getAxis();
1716  ShapeAdaptor firstRankedInputShape = nullptr;
1717  for (const auto &input : inputList) {
1718  const Type inputType = input.getType();
1719  ShapeAdaptor currShape(inputType);
1720  if (currShape.hasRank()) {
1721  firstRankedInputShape = currShape;
1722  // Check axis is in expected range
1723  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1724  return emitOpError("expect axis to be within range 0 < axis < "
1725  "rank(input1[firstRankedTensorIdx]), got ")
1726  << axis;
1727  break;
1728  }
1729  }
1730 
1731  const auto allOperandsHasRank = [](const Value input) {
1732  return ShapeAdaptor(input.getType()).hasRank();
1733  };
1734  if (llvm::all_of(inputList, allOperandsHasRank)) {
1735  const int64_t firstInputRank = firstRankedInputShape.getRank();
1736 
1737  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1738  const ShapeAdaptor inputShape(input.getType());
1739  const int64_t inputRank = inputShape.getRank();
1740  const size_t operandNum = index + 1;
1741 
1742  // Check that each operand has the same rank
1743  if (inputRank != firstInputRank)
1744  return emitOpError(
1745  "expect all operands to have the same rank, but got ")
1746  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1747  << operandNum;
1748 
1749  // Check non-axis dims match
1750  for (int i = 0; i < inputRank; i++) {
1751  const int64_t inputDim = inputShape.getDimSize(i);
1752  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1753  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1754  inputShape.isDynamicDim(i))
1755  continue;
1756  if (inputDim != firstInputDim)
1757  return emitOpError("expect all operand shapes to have the same sizes "
1758  "on non-axis dimensions, but got ")
1759  << inputDim << " vs " << firstInputDim << " at index " << i
1760  << " on operands 0 and " << operandNum;
1761  }
1762  }
1763 
1764  // ERROR_IF(axis_sum != shape[axis]);
1765  int64_t axisSum = 0;
1766  for (const auto &input : inputList) {
1767  const ShapeAdaptor inputShape(input.getType());
1768  if (inputShape.isDynamicDim(axis)) {
1769  // Make axisSum negative to indicate the sum cannot be computed (dynamic dim)
1770  axisSum = -1;
1771  break;
1772  }
1773  axisSum += inputShape.getDimSize(axis);
1774  }
1775  const ShapeAdaptor outputShape(outType);
1776  if (axisSum >= 0 && outputShape.hasRank() &&
1777  !outputShape.isDynamicDim(axis) &&
1778  axisSum != outputShape.getDimSize(axis))
1779  return emitOpError("requires sum of axis dimensions of input1 "
1780  "equal to output axis dimension, got ")
1781  << axisSum << " and " << outputShape.getDimSize(axis);
1782  }
1783 
1784  return success();
1785 }
1786 
1787 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1788  MLIRContext *context, ::std::optional<Location> location,
1789  ValueShapeRange operands, DictionaryAttr attributes,
1790  OpaqueProperties properties, RegionRange regions,
1791  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1792  auto elementType = IntegerType::get(context, /*width=*/1);
1793 
1794  llvm::SmallVector<int64_t> outShape;
1795  if (resolveBroadcastShape(operands, outShape).failed()) {
1796  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1797  return success();
1798  }
1799 
1800  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1801  return success();
1802 }
1803 
1804 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1805  if (l.size() != r.size() || l.size() != 1)
1806  return false;
1807  return succeeded(verifyCompatibleShape(l[0], r[0]));
1808 }
1809 
1810 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1811  MLIRContext *context, ::std::optional<Location> location,
1812  MatMulOp::Adaptor adaptor,
1813  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1814  ShapeAdaptor lhsShape(adaptor.getA().getType());
1815  ShapeAdaptor rhsShape(adaptor.getB().getType());
1816 
1817  // All shapes are dynamic.
1818  SmallVector<int64_t> outShape;
1819  outShape.resize(3, ShapedType::kDynamic);
1820 
1821  if (lhsShape.hasRank()) {
1822  outShape[0] = lhsShape.getDimSize(0);
1823  outShape[1] = lhsShape.getDimSize(1);
1824  }
1825 
1826  if (rhsShape.hasRank()) {
1827  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1828  : outShape[0];
1829  outShape[2] = rhsShape.getDimSize(2);
1830  }
1831 
1832  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1833  return success();
1834 }
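// Example of the inference above (illustrative shapes, not from the code):
//   A: [N, H, C] = [2, 3, 4], B: [N, C, W] = [2, 4, 5]  ->  output [2, 3, 5];
//   the batch dim falls back to B's batch dim when A's batch dim is dynamic.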
1835 
1836 LogicalResult MatMulOp::verify() {
1837  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1838  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1839 
1840  // Must be shaped tensor types
1841  if (!aType)
1842  return emitOpError("expect a shaped tensor for input a, got ")
1843  << getA().getType();
1844 
1845  if (!bType)
1846  return emitOpError("expect a shaped tensor for input b, got ")
1847  << getB().getType();
1848 
1849  auto aElementType = aType.getElementType();
1850  auto bElementType = bType.getElementType();
1851 
1852  auto aQuantizedEType =
1853  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1854  auto bQuantizedEType =
1855  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1856 
1857  if (aQuantizedEType || bQuantizedEType) {
1858  if (!aQuantizedEType || !bQuantizedEType) {
1859  return emitOpError("expect operands to be both quantized or both not "
1860  "quantized, got ")
1861  << aElementType << " and " << bElementType;
1862  }
1863  // both a and b have quantized element types
1864  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1865  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1866  if (aQuantWidth != bQuantWidth) {
1867  return emitOpError("expect quantized operands to have same widths, got ")
1868  << aQuantWidth << " and " << bQuantWidth;
1869  }
1870  }
1871 
1872  // check a_zp and b_zp
1873  auto aEType = getStorageElementTypeOrSelf(aType);
1874  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1875  if (aEType != aZpEType) {
1876  return emitOpError("expect input a and a_zp have the same "
1877  "element type, got ")
1878  << aEType << " and " << aZpEType;
1879  }
1880 
1881  auto bEType = getStorageElementTypeOrSelf(bType);
1882  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1883  if (bEType != bZpEType) {
1884  return emitOpError("expect input b and b_zp have the same "
1885  "element type, got ")
1886  << bEType << " and " << bZpEType;
1887  }
1888 
1889  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1890  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1891  return failure();
1892 
1893  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1894  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1895  return failure();
1896 
1897  return success();
1898 }
1899 
1900 LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
1901  MLIRContext *context, ::std::optional<Location> location,
1902  MatmulTBlockScaledOp::Adaptor adaptor,
1903  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1904  SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
1905 
1906  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
1907  if (aDataShape.hasRank()) {
1908  outShape[0] = aDataShape.getDimSize(0);
1909  outShape[1] = aDataShape.getDimSize(1);
1910  }
1911 
1912  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
1913  if (aScaleShape.hasRank()) {
1914  outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
1915  : outShape[0];
1916  outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
1917  : outShape[1];
1918  }
1919 
1920  // If B batch size is 1, it is broadcast across A's batch size
1921  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
1922  if (bDataShape.hasRank()) {
1923  const int64_t bDataBatchSize = bDataShape.getDimSize(0);
1924  if (bDataBatchSize != 1)
1925  outShape[0] =
1926  ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
1927  outShape[2] = bDataShape.getDimSize(1);
1928  }
1929 
1930  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
1931  if (bScaleShape.hasRank()) {
1932  const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
1933  if (bScaleBatchSize != 1)
1934  outShape[0] =
1935  ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
1936  outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
1937  : outShape[2];
1938  }
1939 
1940  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1941  return success();
1942 }
1943 
1944 LogicalResult MatmulTBlockScaledOp::verify() {
1945  // Verify same input data types
1946  const Type aDataType = getAData().getType();
1947  const Type bDataType = getBData().getType();
1948  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
1949  "B_data")))
1950  return failure();
1951 
1952  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
1953  const StringRef operandName,
1954  const StringRef dimName) -> LogicalResult {
1955  if (ShapedType::isDynamic(currDim)) {
1956  currDim = newDim;
1957  return success();
1958  } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
1959  return emitOpError("expected ")
1960  << dimName << " of " << operandName << " to match size " << currDim
1961  << ", got " << newDim;
1962  }
1963  return success();
1964  };
1965 
1966  // Verify input shape compatibility
1967  int64_t N = ShapedType::kDynamic;
1968  int64_t D = ShapedType::kDynamic;
1969  int64_t H = ShapedType::kDynamic;
1970  int64_t W = ShapedType::kDynamic;
1971  int64_t C = ShapedType::kDynamic;
1972  int64_t multiplesOfC = ShapedType::kDynamic;
1973 
1974  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
1975  if (aDataShape.hasRank()) {
1976  N = aDataShape.getDimSize(0);
1977  H = aDataShape.getDimSize(1);
1978  C = aDataShape.getDimSize(2);
1979  }
1980 
1981  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
1982  if (aScaleShape.hasRank()) {
1983  if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
1984  "batch")) ||
1985  failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
1986  "height")))
1987  return failure();
1988  multiplesOfC = aScaleShape.getDimSize(2);
1989  }
1990 
1991  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
1992  if (bDataShape.hasRank()) {
1993  if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
1994  "batch")) ||
1995  failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
1996  "channels")))
1997  return failure();
1998  W = bDataShape.getDimSize(1);
1999  }
2000 
2001  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
2002  if (bScaleShape.hasRank()) {
2003  if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
2004  "batch")) ||
2005  failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
2006  "width")) ||
2007  failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
2008  "b_scale", "C/block_size")))
2009  return failure();
2010  }
2011 
2012  // Verify batch size is broadcast compatible
2013  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
2014  return emitOpError("expect B matrix batch size to be broadcast compatible "
2015  "with A, got D=")
2016  << D << " vs N=" << N;
2017 
2018  // Verify C is a multiple of block size
2019  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
2020  if (ShapedType::isStatic(C) && C % blockSize != 0)
2021  return emitOpError("expect C to be a multiple of block size, got C=")
2022  << C << ", block_size=" << blockSize;
2023 
2024  // Verify multiplesOfC is C / block size
2025  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
2026  multiplesOfC != C / blockSize)
2027  return emitOpError(
2028  "expect scale operands dimension 2 to equal C/block_size (")
2029  << C << "/" << blockSize << ")"
2030  << ", got " << multiplesOfC;
2031 
2032  // Verify output shape
2033  N = ShapedType::isDynamic(N) ? D : N;
2034  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
2035  const auto outputType = cast<ShapedType>(getResult().getType());
2036  if (outputType.hasRank() &&
2037  failed(
2038  verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
2039  InFlightDiagnostic opError = emitOpError("expected output shape ");
2040  auto stringifyDim = [&](int64_t d) {
2041  if (ShapedType::isDynamic(d))
2042  opError << "?";
2043  else
2044  opError << d;
2045  };
2046  llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
2047  opError << " to be compatible with expected output shape ";
2048  llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
2049  return opError;
2050  }
2051 
2052  return success();
2053 }
2054 
2055 LogicalResult tosa::PadOp::inferReturnTypeComponents(
2056  MLIRContext *context, ::std::optional<Location> location,
2057  PadOp::Adaptor adaptor,
2058  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2059  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2060  auto paddingRank =
2061  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2062  SmallVector<int64_t> outputShape;
2063 
2064  // If the input rank is unknown, we can infer the output rank using the
2065  // padding shape's rank divided by 2.
2066  if (!inputShape.hasRank()) {
2067  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2068  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2069  return success();
2070  }
2071 
2072  SmallVector<int64_t> paddingValues;
2073  // If the padding value is not a constant, all output dimensions must be dynamic.
2074  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2075  paddingValues)) {
2076  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2077  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2078  return success();
2079  }
2080 
2081  outputShape.reserve(inputShape.getRank());
2082  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2083  if (inputShape.isDynamicDim(i)) {
2084  outputShape.push_back(ShapedType::kDynamic);
2085  continue;
2086  }
2087  auto padFront = paddingValues[i * 2];
2088  auto padBack = paddingValues[i * 2 + 1];
2089  if (padFront < 0 || padBack < 0) {
2090  // If either padding value for dim i is negative (dynamic), the output dim is unknown.
2091  outputShape.push_back(ShapedType::kDynamic);
2092  continue;
2093  }
2094 
2095  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2096  }
2097 
2098  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2099  return success();
2100 }
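// Worked example for the inference above (values assumed for illustration):
//   input [2, 3] with constant padding values [1, 1, 2, 2]
//   ->  output [2 + 1 + 1, 3 + 2 + 2] = [4, 7];
//   a negative (dynamic) padding entry makes that output dimension dynamic.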
2101 
2102 LogicalResult tosa::PadOp::verify() {
2103  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2104  /* outType = */ getOutput().getType())
2105  .failed()) {
2106  return failure();
2107  }
2108 
2109  if (auto padConst = getPadConst()) {
2110  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
2111  /* outType = */ getOutput().getType())
2112  .failed()) {
2113  return failure();
2114  }
2115  }
2116 
2117  RankedTensorType inputType =
2118  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
2119  RankedTensorType outputType =
2120  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
2121  if (!inputType || !outputType)
2122  return success();
2123 
2124  auto inputRank = inputType.getRank();
2125  auto outputRank = outputType.getRank();
2126  if (inputRank != outputRank)
2127  return emitOpError() << "expect same input and output tensor rank, but got "
2128  << "inputRank: " << inputRank
2129  << ", outputRank: " << outputRank;
2130 
2131  DenseIntElementsAttr paddingAttr;
2132  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
2133  return failure();
2134  }
2135 
2136  auto paddingValues = paddingAttr.getValues<APInt>();
2137  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
2138  return emitOpError() << "padding tensor must have " << inputRank
2139  << " * 2 = " << inputRank * 2 << " elements, but got "
2140  << paddingValues.size();
2141 
2142  auto inputShape = inputType.getShape();
2143  auto outputShape = outputType.getShape();
2144 
2145  for (int64_t i = 0; i < inputRank; ++i) {
2146  int64_t padStart = paddingValues[i * 2].getSExtValue();
2147  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
2148 
2149  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
2150  return emitOpError()
2151  << "invalid padding values at dimension " << i
2152  << ": values must be non-negative or -1 for dynamic padding, got ["
2153  << padStart << ", " << padEnd << "]";
2154  }
2155 
2156  // Skip shape verification for dynamic input/output
2157  if (inputShape[i] == ShapedType::kDynamic ||
2158  outputShape[i] == ShapedType::kDynamic)
2159  continue;
2160 
2161  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
2162  return emitOpError() << "mismatch in output shape at dimension " << i
2163  << ": expected " << inputShape[i] << " + "
2164  << padStart << " + " << padEnd << " = "
2165  << (inputShape[i] + padStart + padEnd)
2166  << ", but got " << outputShape[i];
2167  }
2168  }
2169 
2170  return success();
2171 }
2172 
2173 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2174  MLIRContext *context, ::std::optional<Location> location,
2175  SliceOp::Adaptor adaptor,
2176  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2177 
2178  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2179  SmallVector<int64_t> start;
2180  SmallVector<int64_t> size;
2181 
2182  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2183  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2184  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2185  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2186  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2187  return success();
2188  }
2189 
2190  // if size[i] is -1, all remaining elements in dimension i are included
2191  // in the slice, similar to TF.
2192  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2193  // initialize outputShape to all unknown
2194  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2195  if (inputShape.hasRank()) {
2196  for (size_t i = 0; i < size.size(); i++) {
2197  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2198  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2199  start[i] < inputShape.getDimSize(i))) {
2200  // size[i] is not 0 and not < -1, and start[i] is in valid range
2201  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2202  // input shape has unknown dim[i] - only valid if size[i] > 0
2203  if (size[i] > 0) {
2204  outputShape[i] = size[i];
2205  }
2206  } else {
2207  // input shape has known dim[i]
2208  if (size[i] == -1) {
2209  outputShape[i] = inputShape.getDimSize(i) - start[i];
2210  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2211  // start[i] + size[i] is within bound of input shape's dim[i]
2212  outputShape[i] = size[i];
2213  }
2214  }
2215  }
2216  }
2217  } else {
2218  outputShape = convertToMlirShape(size);
2219  }
2220  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2221  return success();
2222 }
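// Worked example for the inference above (illustrative values only):
//   input [10, 20], start [1, 2], size [3, -1]
//   ->  output [3, 20 - 2] = [3, 18]; a size of -1 takes all remaining
//   elements, and out-of-range values leave the output dimension dynamic.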
2223 
2224 LogicalResult tosa::SliceOp::verify() {
2225  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2226  /* outType = */ getOutput().getType())
2227  .failed())
2228  return failure();
2229 
2230  const ShapeAdaptor inputShape(getInput1().getType());
2231  if (inputShape.hasRank()) {
2232  const auto inputRank = inputShape.getRank();
2233  const ShapeAdaptor outputShape(getOutput().getType());
2234  if (outputShape.hasRank() && inputRank != outputShape.getRank())
2235  return emitOpError(
2236  "expect input1 and output to have the same ranks, got ")
2237  << inputRank << " and " << outputShape.getRank();
2238 
2239  const auto startShapeRank =
2240  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
2241  if (inputRank != startShapeRank)
2242  return emitOpError("length of start is not equal to rank of input shape");
2243 
2244  const auto sizeShapeRank =
2245  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
2246  if (inputRank != sizeShapeRank)
2247  return emitOpError("length of size is not equal to rank of input shape");
2248  }
2249 
2250  return success();
2251 }
2252 
2253 LogicalResult tosa::MulOp::inferReturnTypeComponents(
2254  MLIRContext *context, ::std::optional<Location> location,
2255  ValueShapeRange operands, DictionaryAttr attributes,
2256  OpaqueProperties properties, RegionRange regions,
2257  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2258  // The mul op's output shape depends only on input1 and input2, not on shift.
2259  ValueShapeRange twoInputs = operands.drop_back();
2260  llvm::SmallVector<int64_t> outShape;
2261  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2262  inferredReturnShapes.push_back(ShapedTypeComponents());
2263  } else {
2264  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2265  }
2266  return success();
2267 }
2268 
2269 LogicalResult tosa::MulOp::verify() {
2270  const Value output = getOutput();
2271  auto resElemType = getElementTypeOrSelf(output);
2272 
2273  // Verify if the element type among operands and result match tosa
2274  // specification.
2275  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2276  IntegerType lhsIntType =
2277  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2278  IntegerType rhsIntType =
2279  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2280  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2281  return emitOpError("requires the same element type for all operands");
2282 
2283  // Though the spec requires the element type of result to be i32, a more
2284  // relaxed way is provided at dialect level for easier cooperating with
2285  // other dialects.
2286  if (lhsIntType.getWidth() > resIntType.getWidth())
2287  return emitOpError("invalid data type size for operands or result");
2288 
2289  } else {
2290  // For the other supported types, the spec requires the same element type
2291  // for all operands (excluding the `shift` operand) and results.
2292  for (int i = 0; i < 2; ++i) {
2293  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2294  return emitOpError(
2295  "requires the same element type for all operands and results");
2296  }
2297 
2298  // verify shift has value 0 for non-integer types
2299  ElementsAttr shift_elem;
2300  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
2301  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
2302  if (shift != 0) {
2303  return emitOpError() << "require shift to be 0 for float type";
2304  }
2305  }
2306  }
2307 
2308  // Verify that all main operands and result types have the same rank, if
2309  // known. Extra operands such as the shift of mul are excluded, which is the
2310  // only difference from the built-in `SameOperandsAndResultRank` trait.
2311  TypeRange operandTypes = getOperandTypes();
2312  ShapedType aType = cast<ShapedType>(operandTypes[0]);
2313  ShapedType bType = cast<ShapedType>(operandTypes[1]);
2314 
2315  const bool aHasRank = aType.hasRank();
2316  const bool bHasRank = bType.hasRank();
2317  if (aHasRank && bHasRank) {
2318  const int64_t aRank = aType.getRank();
2319  const int64_t bRank = bType.getRank();
2320  if (aRank != bRank)
2321  return emitOpError("a and b operands don't have matching ranks, got ")
2322  << aRank << " and " << bRank;
2323 
2324  // check for broadcast compatible shapes
2325  SmallVector<int64_t> resultShape;
2327  aType.getShape(), bType.getShape(), resultShape))
2328  return emitOpError("a and b operands don't have broadcast-compatible "
2329  "shapes, got ")
2330  << aType << " and " << bType;
2331  }
2332 
2333  ShapedType resultType = cast<ShapedType>(output.getType());
2334  if (!resultType.hasRank())
2335  return success();
2336 
2337  const int64_t resultRank = resultType.getRank();
2338  if (aHasRank && resultRank != aType.getRank())
2339  return emitOpError("result type has different rank than a, got ")
2340  << resultRank << " vs " << aType.getRank();
2341  if (bHasRank && resultRank != bType.getRank())
2342  return emitOpError("result type has different rank than b, got ")
2343  << resultRank << " vs " << bType.getRank();
2344 
2345  return success();
2346 }
2347 
2348 LogicalResult tosa::TableOp::inferReturnTypeComponents(
2349  MLIRContext *context, ::std::optional<Location> location,
2350  TableOp::Adaptor adaptor,
2351  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2352  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2353 
2354  if (!inputShape.hasRank()) {
2355  inferredReturnShapes.push_back(ShapedTypeComponents());
2356  return success();
2357  }
2358 
2359  inferredReturnShapes.resize(1);
2360  inputShape.getDims(inferredReturnShapes[0]);
2361  return success();
2362 }
2363 
2364 LogicalResult tosa::TableOp::verify() {
2365  const TensorType inputType = getInput1().getType();
2366  const TensorType outputType = getOutput().getType();
2367 
2368  if (!inputType.hasRank() || !outputType.hasRank())
2369  return success();
2370 
2371  if (inputType.getRank() != outputType.getRank())
2372  return emitOpError()
2373  << "expected input tensor rank to equal result tensor rank";
2374 
2375  auto inputDims = inputType.getShape();
2376  auto outputDims = outputType.getShape();
2377  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2378  int64_t dim = it.index();
2379  auto [inputDim, outputDim] = it.value();
2380  if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2381  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2382  << " doesn't match dim(input, " << dim
2383  << ") = " << inputDim;
2384  }
2385  }
2386  return success();
2387 }
2388 
2389 LogicalResult
2390 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2391  // Multiples must be constants.
2392  DenseIntElementsAttr multiplesAttr;
2393  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2394  return failure();
2395  multiples = llvm::to_vector(
2396  llvm::map_range(multiplesAttr.getValues<APInt>(),
2397  [](const APInt &val) { return val.getSExtValue(); }));
2398  return success();
2399 }
2400 
2401 LogicalResult tosa::TileOp::inferReturnTypeComponents(
2402  MLIRContext *context, ::std::optional<Location> location,
2403  TileOp::Adaptor adaptor,
2404  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2405  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2406  SmallVector<int64_t> multiples;
2407  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2408  multiples)) {
2409  auto rank =
2410  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2411  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2412  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2413  return success();
2414  } else {
2415  multiples = convertToMlirShape(multiples);
2416  }
2417 
2418  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2419  SmallVector<int64_t> outputShape;
2420  if (!inputShape.hasRank()) {
2421  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2422  inferredReturnShapes.push_back(
2423  ShapedTypeComponents(outputShape, inputType));
2424  return success();
2425  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2426  return failure();
2427 
2428  // Any non-dynamic dimension can be multiplied to a known size.
2429  outputShape.reserve(multiples.size());
2430  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2431  if (multiples[i] == ShapedType::kDynamic) {
2432  outputShape.push_back(ShapedType::kDynamic);
2433  } else {
2434  int64_t dim = inputShape.getDimSize(i);
2435  if (dim != ShapedType::kDynamic)
2436  dim *= multiples[i];
2437  outputShape.push_back(dim);
2438  }
2439  }
2440 
2441  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2442  return success();
2443 }
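// Worked example for the inference above (values assumed for illustration):
//   input [2, 3] with constant multiples [2, 3]  ->  output [4, 9];
//   a dynamic input dimension or a -1 multiple yields a dynamic output dim.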
2444 
2445 LogicalResult tosa::TileOp::verify() {
2446  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2447  /* outType = */ getOutput().getType())
2448  .failed()) {
2449  return failure();
2450  }
2451  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2452  ShapedType outputType = llvm::cast<ShapedType>(getType());
2453 
2454  shapeType multiplesType =
2455  llvm::cast<tosa::shapeType>(getMultiples().getType());
2456 
2457  auto multiplesRank = multiplesType.getRank();
2458 
2459  if (inputType.hasRank()) {
2460  if (inputType.getRank() != multiplesRank)
2461  return emitOpError("expect 'multiples' to have rank ")
2462  << inputType.getRank() << " but got " << multiplesRank << ".";
2463  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2464  return emitOpError("expect same input and output tensor rank.");
2465  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2466  return emitOpError("expect 'multiples' array to have length ")
2467  << outputType.getRank() << " but got " << multiplesRank << ".";
2468 
2469  SmallVector<int64_t> multiples;
2470  if (getConstantMultiples(multiples).succeeded() &&
2471  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2472  return emitOpError(
2473  "expect element of 'multiples' to be positive integer or -1.");
2474 
2475  return success();
2476 }
2477 
2478 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2479  if (l.size() != r.size() || l.size() != 1)
2480  return false;
2481  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2482 }
2483 
2484 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2485  MLIRContext *context, ::std::optional<Location> location,
2486  ReshapeOp::Adaptor adaptor,
2487  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2488  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2489  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2490  llvm::SmallVector<int64_t> newShapeValue;
2491  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2492  newShapeValue)) {
2493  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2494  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2495  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2496  return success();
2497  } else {
2498  newShapeValue = convertToMlirShape(newShapeValue);
2499  }
2500 
2501  // We cannot infer from the total number of elements so we must take the
2502  // shape attribute as exact.
2503  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2504  inferredReturnShapes.push_back(
2505  ShapedTypeComponents(newShapeValue, inputType));
2506  return success();
2507  }
2508 
2509  // Determine the number of elements covered by the slice of all static
2510  // dimensions. This allows us to infer the length of the remaining dynamic
2511  // dimension.
2512  int64_t numElements = inputShape.getNumElements();
2513  int64_t staticMul = 1;
2514  for (auto val : newShapeValue) {
2515  if (ShapedType::isStatic(val)) {
2516  staticMul *= val;
2517  }
2518  }
2519 
2520  // Determine the length of the dynamic dimension.
2521  for (auto &val : newShapeValue) {
2522  if (ShapedType::isDynamic(val))
2523  val = numElements / staticMul;
2524  }
2525 
2526  inferredReturnShapes.push_back(
2527  ShapedTypeComponents(newShapeValue, inputType));
2528  return success();
2529 }
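// Worked example for the inference above (illustrative values only):
//   static input [2, 3, 4] (24 elements) reshaped with shape [4, -1]
//   ->  the dynamic dim is inferred as 24 / 4 = 6, giving output [4, 6].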
2530 
2531 llvm::LogicalResult tosa::ReshapeOp::verify() {
2532  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2533  /* outType = */ getOutput().getType())
2534  .failed()) {
2535  return failure();
2536  }
2537  TensorType inputType = getInput1().getType();
2538 
2539  SmallVector<int64_t> shapeValues;
2540  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2541  // skip following checks if shape is not constant
2542  return mlir::success();
2543  }
2544 
2545  int missingDims = llvm::count(shapeValues, -1);
2546  if (missingDims > 1)
2547  return emitOpError() << "expected at most one target dimension to be -1";
2548 
2549  const auto outputType = dyn_cast<RankedTensorType>(getType());
2550  if (!outputType)
2551  return success();
2552 
2553  if ((int64_t)shapeValues.size() != outputType.getRank())
2554  return emitOpError() << "new shape does not match result rank";
2555 
2556  for (auto [newShapeDim, outputShapeDim] :
2557  zip(shapeValues, outputType.getShape())) {
2558  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2559  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2560  return emitOpError() << "new shape is inconsistent with result shape";
2561 
2562  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2563  return emitOpError() << "new shape has invalid tensor dimension size "
2564  << newShapeDim;
2565  }
2566 
2567  if (inputType.hasStaticShape()) {
2568  int64_t inputElementsNum = inputType.getNumElements();
2569  if (outputType.hasStaticShape()) {
2570  int64_t outputElementsNum = outputType.getNumElements();
2571  if (inputElementsNum != outputElementsNum) {
2572  return emitOpError() << "cannot reshape " << inputElementsNum
2573  << " elements into " << outputElementsNum;
2574  }
2575  }
2576 
2577  int64_t newShapeElementsNum =
2578  llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2579  return (dim > 0) ? acc * dim : acc;
2580  });
2581  bool isStaticNewShape =
2582  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2583  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2584  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2585  return emitOpError() << "cannot reshape " << inputElementsNum
2586  << " elements into " << newShapeElementsNum;
2587  }
2588  }
2589 
2590  return mlir::success();
2591 }
2592 
2593 // Return failure if val is not a constant.
2594 // Set zp to -1 if val is a non-zero float or is neither integer nor float.
2595 // Otherwise set zp to val's constant value.
2596 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2597  ElementsAttr zpAttr;
2598  if (!matchPattern(val, m_Constant(&zpAttr))) {
2599  return failure();
2600  }
2601 
2602  Type zpElemType = zpAttr.getElementType();
2603 
2604  if (llvm::isa<FloatType>(zpElemType)) {
2605  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2606  return 0;
2607  }
2608  // return non-zero value to trigger error check
2609  return -1;
2610  }
2611 
2612  if (llvm::isa<IntegerType>(zpElemType)) {
2613  if (signExtend)
2614  return zpAttr.getValues<APInt>()[0].getSExtValue();
2615  else
2616  return zpAttr.getValues<APInt>()[0].getZExtValue();
2617  }
2618 
2619  // return non-zero value to trigger error check
2620  return -1;
2621 }
2622 
2623 template <typename T>
2624 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2625  const std::string &operand) {
2626  Type zpElemType = getElementTypeOrSelf(val);
2627 
2628  if (!zpElemType.isInteger(8) && zp != 0) {
2629  // convert operand to lower case for error message
2630  std::string lower = operand;
2631  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2632  return op.emitOpError()
2633  << lower << " zero point must be zero for non-int8 integer types";
2634  }
2635 
2636  return success();
2637 }
2638 
2639 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2640  const int64_t &zp,
2641  const std::string &operand) {
2642  bool isInputZp = (operand == "Input");
2643 
2644  bool tensorUnsigned =
2645  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2646  StringRef tensorName = isInputZp ? "input" : "output";
2647 
2648  Type zpElemType = getElementTypeOrSelf(zpVal);
2649 
2650  if (zp != 0) {
2651  if (!zpElemType.isInteger(8) &&
2652  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2653  return op.emitOpError()
2654  << "expect " << tensorName << "_zp of 0, got " << zp;
2655  }
2656  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2657  return op.emitOpError() << "expect " << tensorName
2658  << "_zp of 0 or 32768 for unsigned int16 "
2659  << tensorName << ", got " << zp;
2660  }
2661  }
2662 
2663  return success();
2664 }
2665 
2666 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2667  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2668  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2669  } \
2670  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2671  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2672  }
2673 
2674 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2675 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2676 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2677 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2678 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2679 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2680 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2681 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2682 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2683 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2684 ZERO_POINT_HELPER(MatMulOp, A, true)
2685 ZERO_POINT_HELPER(MatMulOp, B, true)
2686 ZERO_POINT_HELPER(NegateOp, Input1, true)
2687 ZERO_POINT_HELPER(NegateOp, Output, true)
2688 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2689 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2690 #undef ZERO_POINT_HELPER
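// The helper macro above is instantiated once per operand that carries a zero
// point; for example, ZERO_POINT_HELPER(MatMulOp, A, true) expands to roughly:
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }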
2691 
2692 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2693  MLIRContext *context, ::std::optional<Location> location,
2694  TransposeOp::Adaptor adaptor,
2695  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2696  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2697 
2698  // If the input rank and the permutation length are unknown, the output
2699  // rank is unknown.
2700  if (!inputShape.hasRank()) {
2701  inferredReturnShapes.push_back(ShapedTypeComponents());
2702  return success();
2703  }
2704 
2705  const auto inputRank = inputShape.getRank();
2706 
2707  // This would imply the number of permutations does not match the rank of
2708  // the input which is illegal.
2709  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2710  return failure();
2711  }
2712 
2713  SmallVector<int64_t> outputShape;
2714  // Rank-0 means no permutations matter.
2715  if (inputRank == 0) {
2716  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2717  return success();
2718  }
2719 
2720  // Check whether the input dimensions are all the same.
2721  bool allTheSame = true;
2722  for (int i = 1, s = inputRank; i < s; i++) {
2723  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2724  allTheSame = false;
2725  break;
2726  }
2727  }
2728 
2729  // If all of the input dimensions are the same we don't care about the
2730  // permutation.
2731  if (allTheSame) {
2732  outputShape.resize(inputRank, inputShape.getDimSize(0));
2733  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2734  return success();
2735  }
2736 
2737  outputShape.resize(inputRank, ShapedType::kDynamic);
2738 
2739  // Constant permutation values must be within the input rank.
2740  if (llvm::any_of(adaptor.getPerms(),
2741  [inputRank](const auto i) { return i >= inputRank; }))
2742  return failure();
2743 
2744  outputShape.reserve(inputRank);
2745  for (int i = 0, s = inputRank; i < s; i++) {
2746  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2747  }
2748 
2749  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2750  return success();
2751 }
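// Example of the inference above (shapes assumed purely for illustration):
//   input [1, 2, 3] with perms = [2, 0, 1]  ->  output [3, 1, 2], since
//   output[i] = input[perms[i]]; equal input dims skip the permutation logic.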
2752 
2753 LogicalResult tosa::TransposeOp::verify() {
2754  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2755  /* outType = */ getOutput().getType())
2756  .failed()) {
2757  return failure();
2758  }
2759 
2760  const ShapeAdaptor inputShape(getInput1().getType());
2761  const ShapeAdaptor outputShape(getOutput().getType());
2762 
2763  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2764 
2765  if (inputShape.hasRank() &&
2766  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2767  return emitOpError() << "expected perms attribute to have size "
2768  << inputShape.getRank()
2769  << " (input rank) but got size "
2770  << constantPerms.size();
2771 
2772  if (inputShape.hasRank() && outputShape.hasRank() &&
2773  inputShape.getRank() != outputShape.getRank())
2774  return emitOpError()
2775  << "expected input tensor rank to equal result tensor rank";
2776 
2777  if (outputShape.hasRank() &&
2778  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2779  return emitOpError() << "expected perms attribute to have size "
2780  << outputShape.getRank()
2781  << " (output rank) but got size "
2782  << constantPerms.size();
2783 
2784  if (!llvm::all_of(constantPerms,
2785  [&constantPerms](int32_t s) {
2786  return s >= 0 &&
2787  static_cast<size_t>(s) < constantPerms.size();
2788  }) ||
2789  !isPermutationVector(llvm::to_vector(llvm::map_range(
2790  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2791  return emitOpError() << "expected valid permutation indices";
2792 
2793  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2794  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2795  inputShape.getNumElements() != outputShape.getNumElements())
2796  return emitOpError() << "expected input1 and output to have same numbers "
2797  "of elements, got "
2798  << inputShape.getNumElements() << " and "
2799  << outputShape.getNumElements();
2800 
2801  // Verify that the types of the input and output tensors are properly
2802  // permuted.
2803  if (inputShape.hasRank() && outputShape.hasRank()) {
2804  for (auto i = 0; i < outputShape.getRank(); i++) {
2805  if (inputShape.isDynamicDim(constantPerms[i]) ||
2806  outputShape.isDynamicDim(i))
2807  continue;
2808 
2809  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2810  return emitOpError()
2811  << "expected output tensor dim " << i << " to match "
2812  << "input dim " << constantPerms[i] << " with value of "
2813  << inputShape.getDimSize(constantPerms[i]);
2814  }
2815  }
2816 
2817  return success();
2818 }
2819 
2820 LogicalResult TransposeOp::reifyResultShapes(
2821  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2822 
2823  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2824 
2825  Value input = getInput1();
2826  auto inputType = cast<TensorType>(input.getType());
2827 
2828  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2829  for (auto dim : transposePerms) {
2830  int32_t dimInInput = transposePerms[dim];
2831  if (inputType.isDynamicDim(dimInInput))
2832  returnedDims[dim] =
2833  tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2834  .getResult();
2835  else
2836  returnedDims[dim] =
2837  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2838  }
2839 
2840  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2841  return success();
2842 }
2843 
2844 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2845  MLIRContext *context, ::std::optional<Location> location,
2846  GatherOp::Adaptor adaptor,
2847  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2848  llvm::SmallVector<int64_t> outputShape;
2849  outputShape.resize(3, ShapedType::kDynamic);
2850 
2851  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2852  if (valuesShape.hasRank()) {
2853  outputShape[0] = valuesShape.getDimSize(0);
2854  outputShape[2] = valuesShape.getDimSize(2);
2855  }
2856 
2857  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2858  if (indicesShape.hasRank()) {
2859  if (outputShape[0] == ShapedType::kDynamic)
2860  outputShape[0] = indicesShape.getDimSize(0);
2861  if (outputShape[1] == ShapedType::kDynamic)
2862  outputShape[1] = indicesShape.getDimSize(1);
2863  }
2864 
2865  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2866  return success();
2867 }
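// Example of the inference above (illustrative shapes, not from the code):
//   values [N, K, C] = [2, 10, 4], indices [N, W] = [2, 6]
//   ->  output [N, W, C] = [2, 6, 4]; the batch dim is taken from the indices
//   operand when it is dynamic on the values operand.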
2868 
2869 LogicalResult tosa::GatherOp::verify() {
2870  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2871  /* outType = */ getOutput().getType())
2872  .failed()) {
2873  return failure();
2874  }
2875 
2876  const ShapeAdaptor valuesShape(getValues().getType());
2877  const ShapeAdaptor indicesShape(getIndices().getType());
2878  const ShapeAdaptor outputShape(getOutput().getType());
2879 
2880  int64_t N = ShapedType::kDynamic;
2881  int64_t W = ShapedType::kDynamic;
2882  int64_t C = ShapedType::kDynamic;
2883 
2884  if (valuesShape.hasRank()) {
2885  N = valuesShape.getDimSize(0);
2886  C = valuesShape.getDimSize(2);
2887  }
2888  if (indicesShape.hasRank()) {
2889  const int64_t indicesN = indicesShape.getDimSize(0);
2890  W = indicesShape.getDimSize(1);
2891  if (N == ShapedType::kDynamic)
2892  N = indicesN;
2893  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2894  return emitOpError() << "requires indices dimension 0 to have size " << N
2895  << ", got " << indicesN;
2896  }
2897  if (outputShape.hasRank()) {
2898  const int64_t outputN = outputShape.getDimSize(0);
2899  const int64_t outputW = outputShape.getDimSize(1);
2900  const int64_t outputC = outputShape.getDimSize(2);
2901  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2902  N != outputN)
2903  return emitOpError() << "requires output dimension 0 to have size " << N
2904  << ", got " << outputN;
2905 
2906  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2907  W != outputW)
2908  return emitOpError() << "requires output dimension 1 to have size " << W
2909  << ", got " << outputW;
2910  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2911  C != outputC)
2912  return emitOpError() << "requires output dimension 2 to have size " << C
2913  << ", got " << outputC;
2914  }
2915  return success();
2916 }
2917 
2918 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2919  MLIRContext *context, ::std::optional<Location> location,
2920  ResizeOp::Adaptor adaptor,
2921  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2922  llvm::SmallVector<int64_t, 4> outputShape;
2923  outputShape.resize(4, ShapedType::kDynamic);
2924 
2925  ShapeAdaptor inputShape(adaptor.getInput().getType());
2926  if (!inputShape.hasRank())
2927  return failure();
2928 
2929  outputShape[0] = inputShape.getDimSize(0);
2930  outputShape[3] = inputShape.getDimSize(3);
2931  int64_t inputHeight = inputShape.getDimSize(1);
2932  int64_t inputWidth = inputShape.getDimSize(2);
2933 
2934  if ((inputHeight == ShapedType::kDynamic) ||
2935  (inputWidth == ShapedType::kDynamic))
2936  return failure();
2937 
2938  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2939  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2940  scaleInt) ||
2941  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2942  offsetInt) ||
2943  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2944  borderInt)) {
2945  return failure();
2946  }
2947 
2948  // Compute the output shape based on attributes: scale, offset, and border.
2949  const int64_t outputHeight =
2950  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2951  scaleInt[1]) +
2952  1;
2953 
2954  const int64_t outputWidth =
2955  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2956  scaleInt[3]) +
2957  1;
2958 
2959  if (outputHeight < 0 || outputWidth < 0) {
2960  return emitOptionalError(
2961  location,
2962  "calculated output height and width must be non-negative, "
2963  "got height = ",
2964  outputHeight, ", width = ", outputWidth);
2965  }
2966 
2967  outputShape[1] = outputHeight;
2968  outputShape[2] = outputWidth;
2969  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2970  return success();
2971 }
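// Worked example for the output size formula above (values assumed only for
// illustration): input H x W = 8 x 8, scale = [2, 1, 2, 1], offset = [0, 0],
// border = [0, 0]:
//   output_height = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15
//   output_width  = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15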
2972 
2973 LogicalResult tosa::ResizeOp::verify() {
2974  const Value input = getInput();
2975  const Value output = getOutput();
2976  const RankedTensorType inputType =
2977  llvm::dyn_cast<RankedTensorType>(input.getType());
2978  const RankedTensorType outputType =
2979  llvm::dyn_cast<RankedTensorType>(output.getType());
2980 
2981  SmallVector<int64_t> scaleValues;
2982  SmallVector<int64_t> offsetValues;
2983  SmallVector<int64_t> borderValues;
2984  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2985  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2986  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2987  // Skip following checks if shape is not constant
2988  return success();
2989  }
2990 
2991  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2992  return emitOpError("expect all scale values to be > 0, got ")
2993  << scaleValues;
2994 
2995  const int64_t scaleYN = scaleValues[0];
2996  const int64_t scaleYD = scaleValues[1];
2997  const int64_t scaleXN = scaleValues[2];
2998  const int64_t scaleXD = scaleValues[3];
2999 
3000  const int64_t offsetY = offsetValues[0];
3001  const int64_t offsetX = offsetValues[1];
3002 
3003  const int64_t borderY = borderValues[0];
3004  const int64_t borderX = borderValues[1];
3005 
3006  if (!inputType)
3007  return success();
3008  if (!outputType)
3009  return success();
3010 
3011  const int64_t oh = outputType.getDimSize(1);
3012  const int64_t ow = outputType.getDimSize(2);
3013  const int64_t ih = inputType.getDimSize(1);
3014  const int64_t iw = inputType.getDimSize(2);
3015 
3016  // Skip this check when the input height could be broadcast (ih == 1),
3017  // since Linalg, a consumer of TOSA, expects broadcasting support in
3018  // resize to be available. Taking the cautious approach for now, we can
3019  // consider removing support for broadcasting later.
3020  if (ih != ShapedType::kDynamic && ih != 1) {
3021  const std::optional<int64_t> calculatedOutHeightMinusOne =
3022  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
3023  if (!calculatedOutHeightMinusOne.has_value())
3024  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
3025  "border_y ")
3026  << "to be wholly divisible by scale_y_d, got ((" << ih
3027  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
3028  << ") / " << scaleYD;
3029  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
3030  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
3031  return emitOpError("calculated output height did not match expected: ")
3032  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
3033  }
3034 
3035  // Skip this check when the input width could be broadcast (iw == 1),
3036  // since Linalg, a consumer of TOSA, expects broadcasting support in
3037  // resize to be available. Taking the cautious approach for now, we can
3038  // consider removing support for broadcasting later.
3039  if (iw != ShapedType::kDynamic && iw != 1) {
3040  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
3041  const std::optional<int64_t> calculatedOutWidthMinusOne =
3042  idivCheck(scaledInWidth, scaleXD);
3043  if (!calculatedOutWidthMinusOne.has_value())
3044  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
3045  "border_x ")
3046  << "to be wholly divisible by scale_x_d, got ((" << iw
3047  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
3048  << ") / " << scaleXD;
3049  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
3050  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
3051  return emitOpError("calculated output width did not match expected: ")
3052  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
3053  }
3054 
3055  return success();
3056 }
3057 
3058 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3059  MLIRContext *context, ::std::optional<Location> location,
3060  ScatterOp::Adaptor adaptor,
3061  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3062  llvm::SmallVector<int64_t> outputShape;
3063  outputShape.resize(3, ShapedType::kDynamic);
3064 
3065  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3066  if (valuesInShape.hasRank()) {
3067  outputShape[0] = valuesInShape.getDimSize(0);
3068  outputShape[1] = valuesInShape.getDimSize(1);
3069  outputShape[2] = valuesInShape.getDimSize(2);
3070  }
3071 
3072  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3073  if (indicesShape.hasRank()) {
3074  if (outputShape[0] == ShapedType::kDynamic)
3075  outputShape[0] = indicesShape.getDimSize(0);
3076  }
3077 
3078  ShapeAdaptor inputShape(adaptor.getInput().getType());
3079  if (inputShape.hasRank()) {
3080  if (outputShape[0] == ShapedType::kDynamic)
3081  outputShape[0] = inputShape.getDimSize(0);
3082  if (outputShape[2] == ShapedType::kDynamic)
3083  outputShape[2] = inputShape.getDimSize(2);
3084  }
3085 
3086  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3087  return success();
3088 }
3089 
3090 LogicalResult tosa::ScatterOp::verify() {
3091  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
3092  /* outType = */ getValuesOut().getType())
3093  .failed() ||
3094  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3095  /* outType = */ getValuesOut().getType())
3096  .failed()) {
3097  return failure();
3098  }
3099 
3100  const ShapeAdaptor valuesInShape(getValuesIn().getType());
3101  const ShapeAdaptor indicesShape(getIndices().getType());
3102  const ShapeAdaptor inputShape(getInput().getType());
3103  const ShapeAdaptor outputShape(getValuesOut().getType());
3104 
3105  int64_t N = ShapedType::kDynamic;
3106  int64_t K = ShapedType::kDynamic;
3107  int64_t W = ShapedType::kDynamic;
3108  int64_t C = ShapedType::kDynamic;
3109  if (valuesInShape.hasRank()) {
3110  N = valuesInShape.getDimSize(0);
3111  K = valuesInShape.getDimSize(1);
3112  C = valuesInShape.getDimSize(2);
3113  }
3114  if (indicesShape.hasRank()) {
3115  const int64_t indicesN = indicesShape.getDimSize(0);
3116  W = indicesShape.getDimSize(1);
3117  if (N == ShapedType::kDynamic)
3118  N = indicesN;
3119  else if (indicesN != ShapedType::kDynamic && N != indicesN)
3120  return emitOpError() << "requires indices dimension 0 to have size " << N
3121  << ", got " << indicesN;
3122  }
3123  if (inputShape.hasRank()) {
3124  const int64_t inputN = inputShape.getDimSize(0);
3125  const int64_t inputW = inputShape.getDimSize(1);
3126  const int64_t inputC = inputShape.getDimSize(2);
3127  if (N == ShapedType::kDynamic)
3128  N = inputN;
3129  else if (inputN != ShapedType::kDynamic && N != inputN)
3130  return emitOpError() << "requires input dimension 0 to have size " << N
3131  << ", got " << inputN;
3132  if (W == ShapedType::kDynamic)
3133  W = inputW;
3134  else if (inputW != ShapedType::kDynamic && W != inputW)
3135  return emitOpError() << "requires input dimension 1 to have size " << W
3136  << ", got " << inputW;
3137 
3138  if (C == ShapedType::kDynamic)
3139  C = inputC;
3140  else if (inputC != ShapedType::kDynamic && C != inputC)
3141  return emitOpError() << "requires input dimension 2 to have size " << C
3142  << ", got " << inputC;
3143  }
3144  if (outputShape.hasRank()) {
3145  const int64_t outputN = outputShape.getDimSize(0);
3146  const int64_t outputK = outputShape.getDimSize(1);
3147  const int64_t outputC = outputShape.getDimSize(2);
3148  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
3149  N != outputN)
3150  return emitOpError() << "requires values_out dimension 0 to have size "
3151  << N << ", got " << outputN;
3152  if (K == ShapedType::kDynamic)
3153  K = outputK;
3154  else if (outputK != ShapedType::kDynamic && K != outputK)
3155  return emitOpError() << "requires values_out dimension 1 to have size "
3156  << K << ", got " << outputK;
3157  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
3158  C != outputC)
3159  return emitOpError() << "requires values_out dimension 2 to have size "
3160  << C << ", got " << outputC;
3161  }
3162  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
3163  return emitOpError() << "requires dimensions K >= W, got K=" << K
3164  << " and W=" << W;
3165 
3166  return success();
3167 }
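
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the K >= W requirement checked above: with K=16 rows in
// values_in/values_out and W=4 indices per batch, every update can land in a
// distinct row, so K may not be smaller than W.
constexpr bool scatterFits(long long k, long long w) { return k >= w; }
static_assert(scatterFits(16, 4), "tosa.scatter requires K >= W");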
3168 
3169 static LogicalResult ReduceInferReturnTypes(
3170  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3171  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3172  int64_t axisVal = axis.getValue().getSExtValue();
3173  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3174  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3175  return success();
3176  }
3177 
3178  SmallVector<int64_t> outputShape;
3179  operandShape.getDims(outputShape);
3180  outputShape[axisVal] = 1;
3181  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3182  return success();
3183 }
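
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the reduce shape rule above, applied to a [2, 3, 4] input with
// axis = 1, which infers a [2, 1, 4] result.
constexpr long long reducedDim(long long dim, long long idx, long long axis) {
  return idx == axis ? 1 : dim;
}
static_assert(reducedDim(3, /*idx=*/1, /*axis=*/1) == 1 &&
                  reducedDim(4, /*idx=*/2, /*axis=*/1) == 4,
              "the reduced axis collapses to 1; other dimensions are kept");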
3184 
3185 #define COMPATIBLE_RETURN_TYPES(OP) \
3186  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3187  if (l.size() != r.size() || l.size() != 1) \
3188  return false; \
3189  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3190  return false; \
3191  return succeeded(verifyCompatibleShape(l[0], r[0])); \
3192  }
3193 
3194 #define REDUCE_SHAPE_INFER(OP) \
3195  LogicalResult OP::inferReturnTypeComponents( \
3196  MLIRContext *context, ::std::optional<Location> location, \
3197  OP::Adaptor adaptor, \
3198  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3199  Type inputType = \
3200  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3201  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3202  const Properties &prop = adaptor.getProperties(); \
3203  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3204  inferredReturnShapes); \
3205  } \
3206  COMPATIBLE_RETURN_TYPES(OP)
3207 
3208 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
3209 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
3210 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
3211 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
3212 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
3213 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
3214 #undef REDUCE_SHAPE_INFER
3215 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
3216 #undef COMPATIBLE_RETURN_TYPES
3217 
3218 template <typename T>
3219 static LogicalResult verifyReduceOp(T op) {
3220  // All TOSA reduce Ops have input, output and axis.
3221  TensorType inputType = op.getInput().getType();
3222  TensorType outputType = op.getOutput().getType();
3223  int32_t reduceAxis = op.getAxis();
3224 
3225  if (reduceAxis < 0) {
3226  op.emitOpError("reduce axis must not be negative");
3227  return failure();
3228  }
3229  if (inputType.hasRank()) {
3230  int64_t inputRank = inputType.getRank();
3231  // We allow for a special case where the input/output shape has rank 0 and
3232  // axis is also 0.
3233  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3234  op.emitOpError("expect input tensor rank (")
3235  << inputRank << ") to be larger than reduce axis (" << reduceAxis
3236  << ")";
3237  return failure();
3238  }
3239  }
3240  if (outputType.hasRank()) {
3241  int64_t outputRank = outputType.getRank();
3242  if (inputType.hasRank() && outputRank != inputType.getRank()) {
3243  op.emitOpError(
3244  "expect output tensor rank to be equal to input tensor rank");
3245  return failure();
3246  }
3247  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3248  op.emitOpError("expect output tensor rank (")
3249  << outputRank << ") to be larger than reduce axis (" << reduceAxis
3250  << ")";
3251  return failure();
3252  }
3253  // We can only verify the reduced dimension size to be 1 if this is not
3254  // the special case of output rank == 0.
3255  if (outputRank != 0) {
3256  auto outputShape = outputType.getShape();
3257  if (!outputType.isDynamicDim(reduceAxis) &&
3258  outputShape[reduceAxis] != 1) {
3259  op.emitOpError("expect reduced dimension size to be 1, got ")
3260  << outputShape[reduceAxis];
3261  return failure();
3262  }
3263  }
3264  }
3265  return success();
3266 }
3267 
3268 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
3269 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
3270 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
3271 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
3272 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
3273 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3274 
3275 static LogicalResult NAryInferReturnTypes(
3276  const ValueShapeRange &operands,
3277  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3278  llvm::SmallVector<int64_t> outShape;
3279  if (resolveBroadcastShape(operands, outShape).failed()) {
3280  inferredReturnShapes.push_back(ShapedTypeComponents());
3281  } else {
3282  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3283  }
3284  return success();
3285 }
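
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the per-dimension rule resolveBroadcastShape applies for the
// elementwise ops below, assuming each dimension pair is either equal or has
// a 1 on one side.
constexpr long long broadcastDim(long long a, long long b) {
  return a == 1 ? b : a;
}
static_assert(broadcastDim(1, 5) == 5 && broadcastDim(5, 1) == 5 &&
                  broadcastDim(3, 3) == 3,
              "size-1 dimensions stretch to the other operand's size");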
3286 
3287 #define NARY_SHAPE_INFER(OP) \
3288  LogicalResult OP::inferReturnTypeComponents( \
3289  MLIRContext *context, ::std::optional<Location> location, \
3290  ValueShapeRange operands, DictionaryAttr attributes, \
3291  OpaqueProperties properties, RegionRange regions, \
3292  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3293  return NAryInferReturnTypes(operands, inferredReturnShapes); \
3294  }
3295 
3296 NARY_SHAPE_INFER(tosa::AbsOp)
3297 NARY_SHAPE_INFER(tosa::AddOp)
3298 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3299 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3300 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3301 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3302 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3303 NARY_SHAPE_INFER(tosa::CastOp)
3304 NARY_SHAPE_INFER(tosa::CeilOp)
3305 NARY_SHAPE_INFER(tosa::ClampOp)
3306 NARY_SHAPE_INFER(tosa::ClzOp)
3307 NARY_SHAPE_INFER(tosa::CosOp)
3308 NARY_SHAPE_INFER(tosa::ExpOp)
3309 NARY_SHAPE_INFER(tosa::FloorOp)
3310 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3311 NARY_SHAPE_INFER(tosa::GreaterOp)
3312 NARY_SHAPE_INFER(tosa::IdentityOp)
3313 NARY_SHAPE_INFER(tosa::IntDivOp)
3314 NARY_SHAPE_INFER(tosa::LogOp)
3315 NARY_SHAPE_INFER(tosa::LogicalAndOp)
3316 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3317 NARY_SHAPE_INFER(tosa::LogicalNotOp)
3318 NARY_SHAPE_INFER(tosa::LogicalOrOp)
3319 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3320 NARY_SHAPE_INFER(tosa::LogicalXorOp)
3321 NARY_SHAPE_INFER(tosa::MaximumOp)
3322 NARY_SHAPE_INFER(tosa::MinimumOp)
3323 NARY_SHAPE_INFER(tosa::PowOp)
3324 NARY_SHAPE_INFER(tosa::ReciprocalOp)
3325 NARY_SHAPE_INFER(tosa::ReverseOp)
3326 NARY_SHAPE_INFER(tosa::RsqrtOp)
3327 NARY_SHAPE_INFER(tosa::SinOp)
3328 NARY_SHAPE_INFER(tosa::SelectOp)
3329 NARY_SHAPE_INFER(tosa::SubOp)
3330 NARY_SHAPE_INFER(tosa::TanhOp)
3331 NARY_SHAPE_INFER(tosa::ErfOp)
3332 NARY_SHAPE_INFER(tosa::SigmoidOp)
3333 #undef NARY_SHAPE_INFER
3334 
3335 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3336  MLIRContext *context, ::std::optional<Location> location,
3337  NegateOp::Adaptor adaptor,
3338  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3339  ShapeAdaptor inputShape(adaptor.getInput1().getType());
3340  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3341  return success();
3342 }
3343 
3344 LogicalResult tosa::NegateOp::verify() {
3345  // Verify same element type
3346  const Type input1Type = getInput1().getType();
3347  const Type outputType = getOutput().getType();
3348  if (verifySameElementTypes(*this, input1Type, outputType).failed())
3349  return failure();
3350 
3351  // Verify same shape
3352  const SmallVector<Type, 2> types = {input1Type, outputType};
3353  if (failed(verifyCompatibleShapes(types)))
3354  return emitOpError() << "requires the same shape for input1 and output";
3355 
3356  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
3357  const Type input1ZpEType =
3358  getStorageElementTypeOrSelf(getInput1Zp().getType());
3359  if (input1EType != input1ZpEType) {
3360  return emitOpError("expect both input1 and its zero point are the same "
3361  "element type, got ")
3362  << input1EType << " and " << input1ZpEType;
3363  }
3364  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
3365  const Type outputZpEType =
3366  getStorageElementTypeOrSelf(getOutputZp().getType());
3367  if (outputEType != outputZpEType) {
3368  return emitOpError("expect both output and its zero point are the same "
3369  "element type, got ")
3370  << outputEType << " and " << outputZpEType;
3371  }
3372 
3373  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3374  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3375  return failure();
3376 
3377  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3378  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3379  return failure();
3380 
3381  return success();
3382 }
3383 
3384 static LogicalResult poolingInferReturnTypes(
3385  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
3386  ArrayRef<int64_t> pad,
3387  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3388  llvm::SmallVector<int64_t> outputShape;
3389  outputShape.resize(4, ShapedType::kDynamic);
3390 
3391  // Without input shape information, all we can report is a fully dynamic 4-D shape.
3392  if (!inputShape) {
3393  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3394  return success();
3395  }
3396 
3397  // The batch size and channel count pass through the pooling layer unchanged.
3398  outputShape[0] = inputShape.getDimSize(0);
3399  outputShape[3] = inputShape.getDimSize(3);
3400 
3401  int64_t height = inputShape.getDimSize(1);
3402  int64_t width = inputShape.getDimSize(2);
3403 
3404  if (ShapedType::isStatic(height)) {
3405  int64_t padded = height + pad[0] + pad[1] - kernel[0];
3406  outputShape[1] = padded / stride[0] + 1;
3407  }
3408 
3409  if (ShapedType::isStatic(width)) {
3410  int64_t padded = width + pad[2] + pad[3] - kernel[1];
3411  outputShape[2] = padded / stride[1] + 1;
3412  }
3413 
3414  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3415  return success();
3416 }
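
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the pooling output-size arithmetic above, evaluated for a 112x112
// input with a 3x3 kernel, stride 2, and pad {0, 1, 0, 1}.
constexpr long long poolOutDim(long long in, long long padLo, long long padHi,
                               long long kernel, long long stride) {
  return (in + padLo + padHi - kernel) / stride + 1;
}
static_assert(poolOutDim(112, 0, 1, 3, 2) == 56,
              "a 112-wide dimension pools down to 56 at stride 2");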
3417 
3418 LogicalResult Conv2DOp::inferReturnTypeComponents(
3419  MLIRContext *context, ::std::optional<Location> location,
3420  Conv2DOp::Adaptor adaptor,
3421  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3422  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3423 
3424  int64_t inputWidth = ShapedType::kDynamic;
3425  int64_t inputHeight = ShapedType::kDynamic;
3426  int64_t weightWidth = ShapedType::kDynamic;
3427  int64_t weightHeight = ShapedType::kDynamic;
3428 
3429  // Input shape describes input width/height and batch.
3430 
3431  ShapeAdaptor inputShape(adaptor.getInput().getType());
3432  if (inputShape.hasRank()) {
3433  outputShape[0] = inputShape.getDimSize(0);
3434  inputHeight = inputShape.getDimSize(1);
3435  inputWidth = inputShape.getDimSize(2);
3436  }
3437 
3438  // The weight shape describes the filter width/height and the output channels.
3439  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3440  if (weightShape.hasRank()) {
3441  outputShape[3] = weightShape.getDimSize(0);
3442  weightHeight = weightShape.getDimSize(1);
3443  weightWidth = weightShape.getDimSize(2);
3444  }
3445 
3446  // Bias shape can describe the output channels.
3447  ShapeAdaptor biasShape(adaptor.getBias().getType());
3448  if (biasShape.hasRank()) {
3449  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3450  ? biasShape.getDimSize(0)
3451  : outputShape[3];
3452  }
3453 
3454  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3455  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3456  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3457 
3458  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3459  int64_t inputSize = inputHeight + padding[0] + padding[1];
3460  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3461  int64_t unstridedResult = inputSize - filterSize + 1;
3462  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3463  }
3464 
3465  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3466  int64_t inputSize = inputWidth + padding[2] + padding[3];
3467  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3468  int64_t unstridedResult = inputSize - filterSize + 1;
3469  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3470  }
3471 
3472  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3473  return success();
3474 }
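
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the convolution output-size arithmetic above, evaluated for a
// 56x56 input, 3x3 kernel, dilation 2, stride 1, and pad {2, 2, 2, 2}.
constexpr long long convOutDim(long long in, long long padLo, long long padHi,
                               long long kernel, long long dilation,
                               long long stride) {
  const long long filter = (kernel - 1) * dilation + 1;
  const long long unstrided = in + padLo + padHi - filter + 1;
  return (unstrided - 1) / stride + 1;
}
static_assert(convOutDim(56, 2, 2, 3, 2, 1) == 56,
              "'same'-style padding keeps the 56-wide dimension at 56");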
3475 
3476 LogicalResult Conv2DOp::verify() {
3477  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3478  verifyConvOpErrorIf(*this).failed())
3479  return failure();
3480  return success();
3481 }
3482 
3483 LogicalResult Conv3DOp::inferReturnTypeComponents(
3484  MLIRContext *context, ::std::optional<Location> location,
3485  Conv3DOp::Adaptor adaptor,
3486  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3487  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3488 
3489  int64_t inputWidth = ShapedType::kDynamic;
3490  int64_t inputHeight = ShapedType::kDynamic;
3491  int64_t inputDepth = ShapedType::kDynamic;
3492 
3493  int64_t weightWidth = ShapedType::kDynamic;
3494  int64_t weightHeight = ShapedType::kDynamic;
3495  int64_t weightDepth = ShapedType::kDynamic;
3496 
3497  // Input shape describes input width/height and batch.
3498  ShapeAdaptor inputShape(adaptor.getInput().getType());
3499  if (inputShape.hasRank()) {
3500  outputShape[0] = inputShape.getDimSize(0);
3501  inputDepth = inputShape.getDimSize(1);
3502  inputHeight = inputShape.getDimSize(2);
3503  inputWidth = inputShape.getDimSize(3);
3504  }
3505 
3506  // The weight shape describes the filter depth/height/width and the output channels.
3507  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3508  if (weightShape.hasRank()) {
3509  outputShape[4] = weightShape.getDimSize(0);
3510  weightDepth = weightShape.getDimSize(1);
3511  weightHeight = weightShape.getDimSize(2);
3512  weightWidth = weightShape.getDimSize(3);
3513  }
3514 
3515  // Bias shape can describe the output channels.
3516  ShapeAdaptor biasShape(adaptor.getBias().getType());
3517  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3518  outputShape[4] = biasShape.getDimSize(0);
3519  }
3520 
3521  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3522  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3523  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3524 
3525  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3526  int32_t inputSize = inputDepth + pad[0] + pad[1];
3527  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3528  int32_t unstridedResult = inputSize - filterSize + 1;
3529  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3530  }
3531 
3532  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3533  int32_t inputSize = inputHeight + pad[2] + pad[3];
3534  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3535  int32_t unstridedResult = inputSize - filterSize + 1;
3536  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3537  }
3538 
3539  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3540  int32_t inputSize = inputWidth + pad[4] + pad[5];
3541  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3542  int32_t unstridedResult = inputSize - filterSize + 1;
3543  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3544  }
3545 
3546  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3547  return success();
3548 }
3549 
3550 LogicalResult Conv3DOp::verify() {
3551  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3552  verifyConvOpErrorIf(*this).failed())
3553  return failure();
3554  return success();
3555 }
3556 
3557 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3558  MLIRContext *context, ::std::optional<Location> location,
3559  AvgPool2dOp::Adaptor adaptor,
3560  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3561  ShapeAdaptor inputShape(adaptor.getInput().getType());
3562  const Properties &prop = adaptor.getProperties();
3563  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3564  inferredReturnShapes);
3565 }
3566 
3567 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3568  MLIRContext *context, ::std::optional<Location> location,
3569  MaxPool2dOp::Adaptor adaptor,
3570  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3571  ShapeAdaptor inputShape(adaptor.getInput().getType());
3572  const Properties &prop = adaptor.getProperties();
3573  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3574  inferredReturnShapes);
3575 }
3576 
3577 LogicalResult MaxPool2dOp::verify() {
3578  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
3579  /* outType = */ getOutput().getType())))
3580  return failure();
3581 
3582  if (failed(verifyPoolingOp(*this)))
3583  return failure();
3584 
3585  return success();
3586 }
3587 
3588 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3589  MLIRContext *context, ::std::optional<Location> location,
3590  DepthwiseConv2DOp::Adaptor adaptor,
3591  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3592  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3593 
3594  int64_t inputWidth = ShapedType::kDynamic;
3595  int64_t inputHeight = ShapedType::kDynamic;
3596  int64_t inputChannels = ShapedType::kDynamic;
3597 
3598  int64_t weightWidth = ShapedType::kDynamic;
3599  int64_t weightHeight = ShapedType::kDynamic;
3600  int64_t depthChannels = ShapedType::kDynamic;
3601 
3602  // Input shape describes input width/height and batch.
3603  ShapeAdaptor inputShape(adaptor.getInput().getType());
3604  if (inputShape.hasRank()) {
3605  outputShape[0] = inputShape.getDimSize(0);
3606  inputHeight = inputShape.getDimSize(1);
3607  inputWidth = inputShape.getDimSize(2);
3608  inputChannels = inputShape.getDimSize(3);
3609  }
3610 
3611  // The weight shape describes the filter width/height, input channels, and channel multiplier.
3612  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3613  if (weightShape.hasRank()) {
3614  weightHeight = weightShape.getDimSize(0);
3615  weightWidth = weightShape.getDimSize(1);
3616  inputChannels = ShapedType::isDynamic(inputChannels)
3617  ? weightShape.getDimSize(2)
3618  : inputChannels;
3619  depthChannels = weightShape.getDimSize(3);
3620  }
3621 
3622  // If both inputChannels and depthChannels are available we can determine
3623  // the output channels.
3624  if (ShapedType::isStatic(inputChannels) &&
3625  ShapedType::isStatic(depthChannels)) {
3626  outputShape[3] = inputChannels * depthChannels;
3627  }
3628 
3629  // Bias shape can describe the output channels.
3630  ShapeAdaptor biasShape(adaptor.getBias().getType());
3631  if (biasShape.hasRank()) {
3632  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3633  ? biasShape.getDimSize(0)
3634  : outputShape[3];
3635  }
3636 
3637  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3638  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3639  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3640 
3641  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3642  int64_t inputSize = inputHeight + padding[0] + padding[1];
3643  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3644  int64_t unstridedResult = inputSize - filterSize + 1;
3645  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3646  }
3647 
3648  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3649  int64_t inputSize = inputWidth + padding[2] + padding[3];
3650  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3651  int64_t unstridedResult = inputSize - filterSize + 1;
3652  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3653  }
3654 
3655  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3656  return success();
3657 }
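
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the depthwise output channel count derived above, for C=32 input
// channels and a channel multiplier M=2 taken from the weight's last
// dimension.
constexpr long long depthwiseOutChannels(long long inChannels,
                                         long long multiplier) {
  return inChannels * multiplier;
}
static_assert(depthwiseOutChannels(32, 2) == 64,
              "depthwise conv produces C * M output channels");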
3658 
3659 LogicalResult DepthwiseConv2DOp::verify() {
3660  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3661  verifyConvOpErrorIf(*this).failed())
3662  return failure();
3663  return success();
3664 }
3665 
3666 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3667  MLIRContext *context, ::std::optional<Location> location,
3668  TransposeConv2DOp::Adaptor adaptor,
3669  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3670  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3671 
3672  int64_t inputWidth = ShapedType::kDynamic;
3673  int64_t inputHeight = ShapedType::kDynamic;
3674  int64_t weightWidth = ShapedType::kDynamic;
3675  int64_t weightHeight = ShapedType::kDynamic;
3676 
3677  // Input shape describes input width/height and batch.
3678  ShapeAdaptor inputShape(adaptor.getInput().getType());
3679  if (inputShape.hasRank()) {
3680  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3681  ? inputShape.getDimSize(0)
3682  : outputShape[0];
3683  inputHeight = inputShape.getDimSize(1);
3684  inputWidth = inputShape.getDimSize(2);
3685  }
3686 
3687  // The weight shape describes the filter width/height and the output channels.
3688  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3689  if (weightShape.hasRank()) {
3690  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3691  ? weightShape.getDimSize(0)
3692  : outputShape[3];
3693  weightHeight = weightShape.getDimSize(1);
3694  weightWidth = weightShape.getDimSize(2);
3695  }
3696 
3697  // Bias shape can describe the output channels.
3698  ShapeAdaptor biasShape(adaptor.getBias().getType());
3699  if (biasShape.hasRank()) {
3700  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3701  ? biasShape.getDimSize(0)
3702  : outputShape[3];
3703  }
3704 
3705  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3706  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3707 
3708  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3709  int64_t calculateSize =
3710  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3711  outputShape[1] =
3712  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3713  }
3714 
3715  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3716  int64_t calculateSize =
3717  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3718  outputShape[2] =
3719  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3720  }
3721 
3722  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3723  return success();
3724 }
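
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the transpose-conv output-size arithmetic above, evaluated for a
// 14x14 input, 3x3 kernel, stride 2, and out_pad {0, 1, 0, 1}.
constexpr long long transposeConvOutDim(long long in, long long stride,
                                        long long padLo, long long padHi,
                                        long long kernel) {
  return (in - 1) * stride + padLo + padHi + kernel;
}
static_assert(transposeConvOutDim(14, 2, 0, 1, 3) == 30,
              "a 14-wide dimension upsamples to 30 at stride 2");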
3725 
3726 LogicalResult TransposeConv2DOp::verify() {
3727  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3728  return failure();
3729 
3730  const llvm::ArrayRef<int64_t> strides = getStride();
3731  const int64_t strideY = strides[0];
3732  const int64_t strideX = strides[1];
3733 
3734  if (strideY < 1 || strideX < 1)
3735  return emitOpError("expect all stride values to be >= 1, got [")
3736  << strides << "]";
3737 
3738  const auto checkPadAgainstKernelDim =
3739  [this](int64_t pad_value, int64_t kernel_dim_size,
3740  llvm::StringRef pad_name,
3741  llvm::StringRef kernel_dim_name) -> LogicalResult {
3742  if (pad_value <= -kernel_dim_size)
3743  return emitOpError("expected ")
3744  << pad_name << " > -" << kernel_dim_name
3745  << ", but got: " << pad_name << "=" << pad_value << " and "
3746  << kernel_dim_name << "=" << kernel_dim_size;
3747  return success();
3748  };
3749 
3750  const llvm::ArrayRef<int64_t> padding = getOutPad();
3751  const int64_t outPadTop = padding[0];
3752  const int64_t outPadBottom = padding[1];
3753  const int64_t outPadLeft = padding[2];
3754  const int64_t outPadRight = padding[3];
3755 
3756  const auto weightType =
3757  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3758 
3759  if (weightType) {
3760  const int64_t kernelHeight = weightType.getDimSize(1);
3761  if (ShapedType::isStatic(kernelHeight)) {
3762  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3763  "out_pad_top", "KH")))
3764  return failure();
3765 
3766  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3767  "out_pad_bottom", "KH")))
3768  return failure();
3769  }
3770 
3771  const int64_t kernelWidth = weightType.getDimSize(2);
3772  if (ShapedType::isStatic(kernelWidth)) {
3773  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3774  "out_pad_left", "KW")))
3775  return failure();
3776 
3777  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3778  "out_pad_right", "KW")))
3779  return failure();
3780  }
3781  }
3782 
3783  // Rest of the checks depend on the output type being a RankedTensorType
3784  const auto outputType =
3785  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3786  if (!outputType)
3787  return success();
3788 
3789  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3790  if (inputType && weightType) {
3791  const int64_t inputHeight = inputType.getDimSize(1);
3792  const int64_t kernelHeight = weightType.getDimSize(1);
3793  const int64_t outputHeight = outputType.getDimSize(1);
3794 
3795  if (ShapedType::isStatic(inputHeight) &&
3796  ShapedType::isStatic(outputHeight)) {
3797  if (outputHeight !=
3798  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3799  return emitOpError(
3800  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3801  "+ out_pad_top + out_pad_bottom + KH, but got ")
3802  << outputHeight << " != (" << inputHeight << " - 1) * "
3803  << strideY << " + " << outPadTop << " + " << outPadBottom
3804  << " + " << kernelHeight;
3805  }
3806 
3807  const int64_t inputWidth = inputType.getDimSize(2);
3808  const int64_t kernelWidth = weightType.getDimSize(2);
3809  const int64_t outputWidth = outputType.getDimSize(2);
3810 
3811  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3812  if (outputWidth !=
3813  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3814  return emitOpError(
3815  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3816  "+ out_pad_left + out_pad_right + KW, but got ")
3817  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3818  << " + " << outPadLeft << " + " << outPadRight << " + "
3819  << kernelWidth;
3820  }
3821  }
3822 
3823  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3824 
3825  if (!biasType)
3826  return success();
3827 
3828  const int64_t biasChannels = biasType.getDimSize(0);
3829 
3830  // Skip further checks if bias is dynamic
3831  if (biasChannels == ShapedType::kDynamic)
3832  return success();
3833 
3834  const int64_t outputChannels = outputType.getDimSize(3);
3835  if (!ShapedType::isDynamic(outputChannels) &&
3836  biasChannels != outputChannels && biasChannels != 1)
3837  return emitOpError(
3838  "bias channels expected to be equal to output channels (")
3839  << outputChannels << ") or 1, got " << biasChannels;
3840 
3841  return success();
3842 }
3843 
3844 LogicalResult RescaleOp::verify() {
3845  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3846  if (!inputType) {
3847  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3848  return failure();
3849  }
3850 
3851  auto inputElementType =
3852  getStorageElementTypeOrSelf(inputType.getElementType());
3853  if (!mlir::isa<IntegerType>(inputElementType)) {
3854  emitOpError("expect input to have integer element type, got ")
3855  << inputElementType;
3856  return failure();
3857  }
3858 
3859  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3860  if (!outputType) {
3861  emitOpError("expect shaped tensor for output, got ")
3862  << getOutput().getType();
3863  return failure();
3864  }
3865 
3866  auto outputElementType =
3867  getStorageElementTypeOrSelf(outputType.getElementType());
3868  if (!mlir::isa<IntegerType>(outputElementType)) {
3869  emitOpError("expect output to have integer element type, got ")
3870  << outputElementType;
3871  return failure();
3872  }
3873 
3874  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3875  .failed())
3876  return failure();
3877 
3878  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3879  .failed())
3880  return failure();
3881 
3882  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3883  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3884  return failure();
3885 
3886  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3887  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3888  return failure();
3889 
3890  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3891  if (!multiplierType) {
3892  emitOpError("expect shaped tensor for multiplier, got ")
3893  << getMultiplier().getType();
3894  return failure();
3895  }
3896 
3897  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3898  if (!shiftType) {
3899  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3900  return failure();
3901  }
3902 
3903  // multiplier element type must be i32 for scale32 = true
3904  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3905  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3906  << multiplierType.getElementType();
3907  return failure();
3908  }
3909 
3910  // multiplier element type must be i16 for scale32 = false
3911  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3912  emitOpError(
3913  "expect i16 element type for multiplier for scale32=false, got ")
3914  << multiplierType.getElementType();
3915  return failure();
3916  }
3917 
3918  if (!inputType.hasRank())
3919  return success();
3920 
3921  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3922  // if per_channel = false, and otherwise is the size of the input shape's
3923  // last dimension.
3924  int64_t numChannels = 1;
3925  if (getPerChannel()) {
3926  if (inputType.getRank() < 1) {
3927  emitOpError("requires input to be at least rank 1 when per_channel is "
3928  "true, but got rank ")
3929  << inputType.getRank();
3930  return failure();
3931  }
3932  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3933  }
3934 
3935  if (!multiplierType.hasRank())
3936  return success();
3937 
3938  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3939  // multiplier input has rank 1 by dialect definition
3940  if (multiplierShape[0] != ShapedType::kDynamic &&
3941  multiplierShape[0] != numChannels) {
3942  emitOpError("expect shape of { ")
3943  << numChannels << " } for multiplier input, got { "
3944  << multiplierShape[0] << " }";
3945  return failure();
3946  }
3947 
3948  if (!shiftType.hasRank())
3949  return success();
3950 
3951  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3952  // shift input has rank 1 by dialect definition
3953  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3954  emitOpError("expect shape of { ")
3955  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3956  return failure();
3957  }
3958 
3959  return success();
3960 }
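
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the multiplier/shift length rule checked above, for a rank-4 input
// whose last dimension has 64 channels.
constexpr long long rescaleNumChannels(bool perChannel, long long lastDim) {
  return perChannel ? lastDim : 1;
}
static_assert(rescaleNumChannels(false, 64) == 1 &&
                  rescaleNumChannels(true, 64) == 64,
              "multiplier and shift carry one entry per rescale channel");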
3961 
3962 LogicalResult RescaleOp::inferReturnTypeComponents(
3963  MLIRContext *context, ::std::optional<Location> location,
3964  RescaleOp::Adaptor adaptor,
3965  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3966  ShapeAdaptor inputShape(adaptor.getInput().getType());
3967  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3968  return success();
3969 }
3970 
3971 LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
3972  MLIRContext *context, ::std::optional<Location> location,
3973  CastFromBlockScaledOp::Adaptor adaptor,
3974  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3975  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
3976  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3977  return success();
3978 }
3979 
3980 LogicalResult CastFromBlockScaledOp::verify() {
3981  const Type inputDataType = getInputData().getType();
3982  const Type outputDataType = getResult().getType();
3983  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
3984  return emitOpError() << "require compatible shapes for input_data ("
3985  << inputDataType << ") and "
3986  << "output_data (" << outputDataType << ")";
3987 
3988  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
3989 
3990  if (inputDataShape.hasRank()) {
3991  const unsigned int blockSize =
3992  BlockSizeAttr::getBlockSizeValue(getBlockSize());
3993  const int64_t inputDataLastDim =
3994  inputDataShape.getDimSize(inputDataShape.getRank() - 1);
3995  if (inputDataLastDim % blockSize != 0)
3996  return emitOpError() << "expect last dimension of input_data ("
3997  << inputDataLastDim
3998  << ") to be divisible by block_size (" << blockSize
3999  << ")";
4000 
4001  const Type inputScaleType = getInputScale().getType();
4002  const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
4003 
4004  if (inputScaleShape.hasRank()) {
4005  SmallVector<int64_t> inputDataDims, inputScaleDims;
4006  inputDataShape.getDims(inputDataDims);
4007  inputScaleShape.getDims(inputScaleDims);
4008 
4009  if (inputDataDims.size() != inputScaleDims.size() ||
4010  failed(verifyCompatibleShape(
4011  ArrayRef<int64_t>(inputDataDims).drop_back(1),
4012  ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
4013  return emitOpError() << "require compatible shapes for input_data ("
4014  << inputDataType << ") and "
4015  << "input_scale (" << inputScaleType
4016  << ") except for the last dimension";
4017 
4018  const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
4019  inputScaleDims.back()};
4020  if (ShapedType::isStatic(inputDataLastDim) &&
4021  failed(verifyCompatibleDims(dimsToCheck)))
4022  return emitOpError()
4023  << "expect last dimension of input_scale ("
4024  << inputScaleDims.back()
4025  << ") to be equal to last dimension of input_data / block_size ("
4026  << inputDataDims.back() / blockSize << ")";
4027  }
4028  }
4029 
4030  return success();
4031 }
4032 
4033 LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
4034  MLIRContext *context, ::std::optional<Location> location,
4035  CastToBlockScaledOp::Adaptor adaptor,
4036  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4037  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4038  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4039  if (!inputShape.hasRank())
4040  return success();
4041 
4042  // Calculate output_scale shape if ranked input provided
4043  SmallVector<int64_t> outputScaleShape;
4044  inputShape.getDims(outputScaleShape);
4045  const int64_t lastDimLoc = inputShape.getRank() - 1;
4046  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
4047  if (ShapedType::isStatic(lastDimSize)) {
4048  const unsigned int blockSize =
4049  BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
4050  outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4051  }
4052  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
4053  return success();
4054 }
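
// A minimal illustrative sketch (assumed example values, hypothetical helper
// name) of the output_scale shape derived above, for an input_data last
// dimension of 64 and a block size of 32: one scale value per block.
constexpr long long scaleLastDim(long long dataLastDim, long long blockSize) {
  return dataLastDim / blockSize;
}
static_assert(scaleLastDim(64, 32) == 2,
              "a 64-element last dimension yields 2 scales at block_size 32");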
4055 
4056 LogicalResult CastToBlockScaledOp::verify() {
4057  const Type inputDataType = getInputData().getType();
4058  const Type outputDataType = getResult(0).getType();
4059  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4060  return emitOpError() << "require compatible shapes for input_data ("
4061  << inputDataType << ") and "
4062  << "output_data (" << outputDataType << ")";
4063 
4064  const unsigned int blockSize =
4065  BlockSizeAttr::getBlockSizeValue(getBlockSize());
4066  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4067  if (inputDataShape.hasRank()) {
4068  const int64_t inputDataLastDim =
4069  inputDataShape.getDimSize(inputDataShape.getRank() - 1);
4070  if (ShapedType::isStatic(inputDataLastDim) &&
4071  inputDataLastDim % blockSize != 0)
4072  return emitOpError() << "expect last dimension of input_data ("
4073  << inputDataLastDim
4074  << ") to be divisible by block_size (" << blockSize
4075  << ")";
4076  }
4077 
4078  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
4079  const Type outputScaleType = getResult(1).getType();
4080  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
4081  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
4082  SmallVector<int64_t> outputDataDims, outputScaleDims;
4083  outputDataShape.getDims(outputDataDims);
4084  outputScaleShape.getDims(outputScaleDims);
4085 
4086  if (outputDataDims.size() != outputScaleDims.size() ||
4087  failed(verifyCompatibleShape(
4088  ArrayRef<int64_t>(outputDataDims).drop_back(1),
4089  ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
4090  return emitOpError() << "require compatible shapes for output_data ("
4091  << outputDataType << ") and "
4092  << "output_scale (" << outputScaleType
4093  << ") except for the last dimension";
4094 
4095  const int64_t outputDataLastDim = outputDataDims.back();
4096  const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
4097  outputScaleDims.back()};
4098  if (ShapedType::isStatic(outputDataLastDim) &&
4099  failed(verifyCompatibleDims(dimsToCheck)))
4100  return emitOpError()
4101  << "expect last dimension of output_scale ("
4102  << outputScaleDims.back()
4103  << ") to be equal to last dimension of output_data / block_size ("
4104  << outputDataDims.back() / blockSize << ")";
4105  }
4106 
4107  return success();
4108 }
4109 
4110 LogicalResult IfOp::inferReturnTypeComponents(
4111  MLIRContext *context, ::std::optional<Location> location,
4112  IfOp::Adaptor adaptor,
4113  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4114  llvm::SmallVector<tosa::YieldOp> yieldOps;
4115  for (Region *region : adaptor.getRegions()) {
4116  for (auto &block : *region)
4117  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4118  yieldOps.push_back(returnOp);
4119  }
4120 
4121  if (yieldOps.empty())
4122  return failure();
4123 
4124  // Get the initial type information for the yield op.
4125  llvm::SmallVector<ValueKnowledge> resultKnowledge;
4126  resultKnowledge.reserve(yieldOps.front().getNumOperands());
4127  for (auto operand : yieldOps.front().getOperands()) {
4128  resultKnowledge.push_back(
4129  ValueKnowledge::getKnowledgeFromType(operand.getType()));
4130  }
4131 
4132  for (auto yieldOp : yieldOps) {
4133  if (resultKnowledge.size() != yieldOp.getNumOperands())
4134  return failure();
4135 
4136  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4137  int32_t index = it.index();
4138  auto meet = ValueKnowledge::meet(
4139  resultKnowledge[index],
4140  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
4141  if (!meet)
4142  continue;
4143  resultKnowledge[index] = meet;
4144  }
4145  }
4146 
4147  for (const ValueKnowledge &result : resultKnowledge) {
4148  inferredReturnShapes.push_back(result.getShapedTypeComponents());
4149  }
4150 
4151  return success();
4152 }
4153 
4154 LogicalResult WhileOp::inferReturnTypeComponents(
4155  MLIRContext *context, ::std::optional<Location> location,
4156  WhileOp::Adaptor adaptor,
4157  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4158  llvm::SmallVector<tosa::YieldOp> yieldOps;
4159  for (auto &block : adaptor.getBodyGraph())
4160  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
4161  yieldOps.push_back(returnOp);
4162 
4163  // TOSA's while must have a tosa.yield as its terminator. If none is found,
4164  // this tosa.while is invalid.
4165  if (yieldOps.empty())
4166  return failure();
4167 
4168  // Get the initial type information from the operand types.
4169  llvm::SmallVector<ValueKnowledge> resultKnowledge;
4170  resultKnowledge.reserve(yieldOps.front().getNumOperands());
4171  for (auto operand : yieldOps.front().getOperands()) {
4172  resultKnowledge.push_back(
4173  ValueKnowledge::getKnowledgeFromType(operand.getType()));
4174  }
4175 
4176  for (auto yieldOp : yieldOps) {
4177  if (resultKnowledge.size() != yieldOp.getNumOperands())
4178  return failure();
4179 
4180  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
4181  int32_t index = it.index();
4182  if (auto meet = ValueKnowledge::meet(
4183  resultKnowledge[index],
4184  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
4185  resultKnowledge[index] = meet;
4186  }
4187  }
4188  }
4189 
4190  for (const ValueKnowledge &result : resultKnowledge) {
4191  inferredReturnShapes.push_back(result.getShapedTypeComponents());
4192  }
4193 
4194  return success();
4195 }
4196 
4197 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
4198  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
4199  return llvm::to_vector<4>(vt.getShape());
4200  return std::nullopt;
4201 }
4202 
4203 static void printInitializationList(OpAsmPrinter &parser,
4204  Block::BlockArgListType blocksArgs,
4205  ValueRange initializers,
4206  StringRef prefix = "") {
4207  assert(blocksArgs.size() == initializers.size() &&
4208  "expected same length of arguments and initializers");
4209  if (initializers.empty())
4210  return;
4211 
4212  parser << prefix << '(';
4213  llvm::interleaveComma(
4214  llvm::zip(blocksArgs, initializers), parser,
4215  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
4216  parser << ")";
4217 }
4218 
4219 // The parse and print methods of IfOp follow the implementation in the SCF dialect.
4220 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
4221  // Create the regions for 'then'.
4222  result.regions.reserve(2);
4223  Region *thenRegion = result.addRegion();
4224  Region *elseRegion = result.addRegion();
4225 
4226  OpAsmParser::UnresolvedOperand cond;
4227 
4228  if (parser.parseOperand(cond))
4229  return failure();
4230 
4231  SmallVector<OpAsmParser::Argument, 4> regionArgs;
4232  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4233 
4234  // Parse the optional block arguments
4235  OptionalParseResult listResult =
4236  parser.parseOptionalAssignmentList(regionArgs, operands);
4237  if (listResult.has_value() && failed(listResult.value()))
4238  return failure();
4239 
4240  // Parse a colon.
4241  if (failed(parser.parseColon()))
4242  return parser.emitError(parser.getCurrentLocation(),
4243  "expected type for condition operand");
4244 
4245  // Parse the type of the condition operand
4246  Type condType;
4247  if (failed(parser.parseType(condType)))
4248  return parser.emitError(parser.getCurrentLocation(),
4249  "expected type for condition operand");
4250 
4251  // Resolve operand with provided type
4252  if (failed(parser.resolveOperand(cond, condType, result.operands)))
4253  return failure();
4254 
4255  // Parse optional block arg types
4256  if (listResult.has_value()) {
4257  FunctionType functionType;
4258 
4259  if (failed(parser.parseType(functionType)))
4260  return parser.emitError(parser.getCurrentLocation())
4261  << "expected list of types for block arguments "
4262  << "followed by arrow type and list of return types";
4263 
4264  result.addTypes(functionType.getResults());
4265 
4266  if (functionType.getNumInputs() != operands.size()) {
4267  return parser.emitError(parser.getCurrentLocation())
4268  << "expected as many input types as operands "
4269  << "(expected " << operands.size() << " got "
4270  << functionType.getNumInputs() << ")";
4271  }
4272 
4273  // Resolve input operands.
4274  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4275  parser.getCurrentLocation(),
4276  result.operands)))
4277  return failure();
4278  } else {
4279  // Parse optional results type list.
4280  if (parser.parseOptionalArrowTypeList(result.types))
4281  return failure();
4282  }
4283 
4284  // Parse the 'then' region.
4285  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
4286  return failure();
4287 
4288  // If we find an 'else' keyword then parse the 'else' region.
4289  if (!parser.parseOptionalKeyword("else")) {
4290  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
4291  return failure();
4292  }
4293 
4294  // Parse the optional attribute list.
4295  if (parser.parseOptionalAttrDict(result.attributes))
4296  return failure();
4297  return success();
4298 }
4299 
4300 void IfOp::print(OpAsmPrinter &p) {
4301  p << " " << getCondition();
4302 
4303  printInitializationList(p, getThenGraph().front().getArguments(),
4304  getInputList(), " ");
4305  p << " : ";
4306  p << getCondition().getType();
4307 
4308  if (!getInputList().empty()) {
4309  p << " (";
4310  llvm::interleaveComma(getInputList().getTypes(), p);
4311  p << ")";
4312  }
4313  p.printArrowTypeList(getResultTypes());
4314  p << " ";
4315 
4316  p.printRegion(getThenGraph());
4317 
4318  // Print the 'else' region if it exists and has a block.
4319  auto &elseRegion = getElseGraph();
4320  if (!elseRegion.empty()) {
4321  p << " else ";
4322  p.printRegion(elseRegion);
4323  }
4324 
4325  p.printOptionalAttrDict((*this)->getAttrs());
4326 }
4327 
4328 LogicalResult IfOp::verify() {
4329  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
4330  "'then_graph' arguments", getInputList(),
4331  "'input_list'")
4332  .failed())
4333  return failure();
4334 
4335  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
4336  "'else_graph' arguments", getInputList(),
4337  "'input_list'")
4338  .failed())
4339  return failure();
4340 
4341  // If the terminator is missing, MLIR's own verifier reports it for us.
4342  if (getThenGraph().front().mightHaveTerminator()) {
4343  auto thenYield =
4344  dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4345  if (thenYield && errorIfTypeOrShapeMismatch(
4346  *this, thenYield.getInputs(), "'then_graph' results",
4347  getOutputList(), "'output_list'")
4348  .failed())
4349  return failure();
4350  }
4351 
4352  // If the terminator is missing, MLIR's own verifier reports it for us.
4353  if (getElseGraph().front().mightHaveTerminator()) {
4354  auto elseYield =
4355  dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4356  if (elseYield && errorIfTypeOrShapeMismatch(
4357  *this, elseYield.getInputs(), "'else_graph' results",
4358  getOutputList(), "'output_list'")
4359  .failed())
4360  return failure();
4361  }
4362 
4363  auto condType = getCondition().getType();
4364  if (errorIfShapeNotSizeOne(*this, condType).failed())
4365  return emitOpError() << "'condition' must be a size 1 tensor, got "
4366  << condType;
4367 
4368  return success();
4369 }
4370 
4371 LogicalResult WhileOp::verify() {
4372  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
4373  getOutputList(), "'output_list'")
4374  .failed())
4375  return failure();
4376 
4377  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
4378  "'cond_graph' arguments", getInputList(),
4379  "'input_list'")
4380  .failed())
4381  return failure();
4382 
4383  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
4384  "'body_graph' arguments", getInputList(),
4385  "'input_list'")
4386  .failed())
4387  return failure();
4388 
4389  if (getBodyGraph().front().mightHaveTerminator()) {
4390  auto bodyYield =
4391  dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4392  if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
4393  "'body_graph' results",
4394  getInputList(), "'input_list'")
4395  .failed())
4396  return failure();
4397  }
4398 
4399  // Condition block output must be a single element tensor with a single bool
4400  // value.
4401  if (!getCondGraph().front().mightHaveTerminator())
4402  return success();
4403 
4404  auto condYield =
4405  dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4406  if (!condYield)
4407  return success();
4408 
4409  if (condYield.getInputs().size() != 1)
4410  return emitOpError() << "require 'cond_graph' to have exactly one result";
4411 
4412  auto condOutType = condYield.getInputs()[0].getType();
4413  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
4414  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4415  << condOutType;
4416 
4417  if (!getElementTypeOrSelf(condOutType).isInteger(1))
4418  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4419  << condOutType;
4420 
4421  return success();
4422 }
4423 
4424 LogicalResult ReverseOp::verify() {
4425  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
4426  /* outType = */ getOutput().getType())
4427  .failed())
4428  return failure();
4429  TensorType inputType = getInput1().getType();
4430  TensorType outputType = getOutput().getType();
4431  int32_t reverseAxis = getAxis();
4432 
4433  if (reverseAxis < 0)
4434  return emitOpError("expected non-negative reverse axis");
4435  if (inputType.hasRank()) {
4436  int64_t inputRank = inputType.getRank();
4437  // We allow for a special case where the input/output shape has rank 0 and
4438  // axis is also 0.
4439  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4440  return emitOpError("expect input tensor rank (")
4441  << inputRank << ") to be larger than reverse axis (" << reverseAxis
4442  << ")";
4443  }
4444  if (outputType.hasRank()) {
4445  int64_t outputRank = outputType.getRank();
4446  if (inputType.hasRank() && outputRank != inputType.getRank())
4447  return emitOpError(
4448  "expect output tensor rank to be equal to input tensor rank");
4449  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4450  return emitOpError("expect output tensor rank (")
4451  << outputRank << ") to be larger than reverse axis ("
4452  << reverseAxis << ")";
4453  }
4454  return success();
4455 }
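
The rank/axis rule above has one subtlety: a rank-0 tensor still accepts axis 0 even though 0 is not strictly less than the rank. The following standalone restatement (a hypothetical helper, for illustration only) captures the same predicate the verifier checks for both the input and output ranks.

#include <cstdint>

// Hypothetical restatement of the ReverseOp axis rule: the axis must be
// non-negative and less than the rank, except that rank-0 tensors accept
// axis 0 as a degenerate case.
static bool isValidReverseAxis(int32_t axis, int64_t rank) {
  if (axis < 0)
    return false;
  if (rank == 0)
    return axis == 0;
  return axis < rank;
}

int main() {
  bool ok = isValidReverseAxis(0, 0) && isValidReverseAxis(2, 3) &&
            !isValidReverseAxis(3, 3) && !isValidReverseAxis(-1, 3);
  return ok ? 0 : 1;
}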
4456 
4457 LogicalResult tosa::SelectOp::verify() {
4458  // verify input2 and input3 have same element type as output
4459  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
4460  /* outType = */ getOutput().getType())
4461  .failed() ||
4462  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
4463  /* outType = */ getOutput().getType())
4464  .failed()) {
4465  return failure();
4466  }
4467  // verify input1 has element type of bool
4468  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4469  if (!predicateType) {
4470  return emitOpError("expect shaped tensor for input1, got ")
4471  << getInput1().getType();
4472  }
4473  auto predicateElementType = predicateType.getElementType();
4474  if (!predicateElementType.isInteger(1)) {
4475  return emitOpError("expect element type of bool for input1, got ")
4476  << predicateElementType;
4477  }
4478 
4479  return success();
4480 }
4481 
4482 LogicalResult tosa::VariableReadOp::verify() {
4483  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
4484  .failed())
4485  return failure();
4486 
4487  return success();
4488 }
4489 
4490 LogicalResult tosa::VariableWriteOp::verify() {
4491  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
4492  .failed())
4493  return failure();
4494 
4495  return success();
4496 }
4497 
4498 // The parse and print of WhileOp follow the implementation in the SCF dialect.
4499 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
4500  SmallVector<OpAsmParser::Argument, 4> regionArgs;
4501  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
4502  Region *cond = result.addRegion();
4503  Region *body = result.addRegion();
4504 
4505  OptionalParseResult listResult =
4506  parser.parseOptionalAssignmentList(regionArgs, operands);
4507  if (listResult.has_value() && failed(listResult.value()))
4508  return failure();
4509 
4510  FunctionType functionType;
4511  SMLoc typeLoc = parser.getCurrentLocation();
4512  if (failed(parser.parseColonType(functionType)))
4513  return failure();
4514 
4515  result.addTypes(functionType.getResults());
4516 
4517  if (functionType.getNumInputs() != operands.size()) {
4518  return parser.emitError(typeLoc)
4519  << "expected as many input types as operands "
4520  << "(expected " << operands.size() << " got "
4521  << functionType.getNumInputs() << ")";
4522  }
4523 
4524  // Resolve input operands.
4525  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
4526  parser.getCurrentLocation(),
4527  result.operands)))
4528  return failure();
4529 
4530  // Propagate the types into the region arguments.
4531  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4532  regionArgs[i].type = functionType.getInput(i);
4533 
4534  return failure(parser.parseRegion(*cond, regionArgs) ||
4535  parser.parseKeyword("do") || parser.parseRegion(*body) ||
4536  parser.parseOptionalAttrDictWithKeyword(result.attributes));
4537 }
4538 
4539 void WhileOp::print(OpAsmPrinter &printer) {
4540  printInitializationList(printer, getCondGraph().front().getArguments(),
4541  getInputList(), " ");
4542  printer << " : ";
4543  printer.printFunctionalType(getInputList().getTypes(),
4544  getResults().getTypes());
4545  printer << ' ';
4546  printer.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
4547  printer << " do ";
4548  printer.printRegion(getBodyGraph());
4549  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs());
4550 }
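
For orientation, here is an illustrative reconstruction of the textual form that the parse/print pair above round-trips, derived from the calls they make (an optional "%arg = %init" assignment list, a functional type, the cond region, the "do" keyword, the body region, then an optional attribute dictionary). The operation mnemonic, operand names, and region contents below are assumptions, not taken from this file.

// Illustrative only; assumed operand names and region bodies:
//
//   %res = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     // cond_graph: must yield a single size-1 boolean tensor
//   } do {
//     // body_graph: yields the values carried into the next iteration
//   }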
4551 
4552 // Create a rank-1 const tensor holding the zero point for the source tensor.
4553 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
4554  Location loc,
4555  Type srcElemType,
4556  int64_t zp) {
4557  srcElemType = getStorageElementTypeOrSelf(srcElemType);
4558  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
4559  if (llvm::isa<FloatType>(srcElemType)) {
4560  auto zpAttr = DenseElementsAttr::get(
4561  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4562  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4563  }
4564  if (llvm::isa<IntegerType>(srcElemType)) {
4565  auto zpAttr =
4566  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
4567  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4568  }
4569  llvm::errs() << "zero point is not allowed for unsupported data types\n";
4570  return std::nullopt;
4571 }
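
A hedged sketch of how a lowering or canonicalization might call this helper follows; the wrapper name is made up, and it assumes createZeroPointTensor is declared through the TOSA dialect headers.

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"

#include <optional>

// Hypothetical call site (the wrapper name is made up): materialize a zero
// point of 0 for an i8 operand. A null Value is returned when the element
// type is unsupported and no constant was created.
static mlir::Value materializeI8ZeroPoint(mlir::OpBuilder &builder,
                                          mlir::Location loc) {
  std::optional<mlir::Value> zp = mlir::tosa::createZeroPointTensor(
      builder, loc, builder.getIntegerType(8), /*zp=*/0);
  return zp.value_or(mlir::Value());
}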
4572 
4573 //===----------------------------------------------------------------------===//
4574 // TOSA Shape and Shape Operators Helper functions.
4575 //===----------------------------------------------------------------------===//
4576 
4577 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
4578  return mlir::isa<tosa::shapeType>(t);
4579 }
4580 
4581 LogicalResult
4582 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
4583  int rank) {
4584  if (rank < 0)
4585  return emitError() << "invalid rank (must be >= 0): " << rank;
4586  return success();
4587 }
4588 
4589 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
4590  for (auto v : op->getOperands()) {
4591  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4592  Operation *definingOp = v.getDefiningOp();
4593  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
4594  return op->emitOpError("shape operand is not compile time resolvable");
4595  }
4596  }
4597  }
4598  return success();
4599 }
4600 
4601 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
4602  for (auto type : op->getOperandTypes()) {
4603  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4604  return op->emitOpError("must have operands with tosa shape type");
4605  }
4606  }
4607  for (auto type : op->getResultTypes()) {
4608  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4609  return op->emitOpError("must have result with tosa shape type");
4610  }
4611  }
4612  return success();
4613 }
4614 
4615 LogicalResult
4616 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4617  if (failed(verifyAtLeastNOperands(op, 1)) ||
4618  failed(verifyTosaShapeOperator(op)))
4619  return failure();
4620 
4621  // Helper lambda that returns the rank of a tosa shape type.
4622  auto getRank = [](const Type type) {
4623  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4624  };
4625  auto operandTypes = op->getOperandTypes();
4626  auto resultTypes = op->getResultTypes();
4627 
4628  auto rank = getRank(*op->getOperandTypes().begin());
4629  for (auto type : operandTypes) {
4630  if (getRank(type) != rank) {
4631  return op->emitOpError("operands don't have matching ranks");
4632  }
4633  }
4634  for (auto type : resultTypes) {
4635  if (getRank(type) != rank) {
4636  return op->emitOpError("result shape has different rank than operands");
4637  }
4638  }
4639  return success();
4640 }
4641 
4642 //===----------------------------------------------------------------------===//
4643 // TOSA Shape Operators verify functions.
4644 //===----------------------------------------------------------------------===//
4645 
4646 LogicalResult tosa::ConstShapeOp::verify() {
4647  // Check that the values attribute is one-dimensional (rank 1).
4648  auto valuesRank = getValues().getType().getRank();
4649  if (valuesRank != 1)
4650  return emitOpError("expect elements in attribute values with rank 1");
4651  // Check that the number of elements in the values attribute equals the rank of the result shape.
4652  auto count = getValues().getNumElements();
4653  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4654  if (count != rank && (count != 1 || rank != 0)) {
4655  return emitOpError("expect number of elements in attribute values (")
4656  << count << ") to be equal to the rank (" << rank
4657  << ") for the result shape type";
4658  }
4659  return success();
4660 }
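
The guard above is easy to misread: a count/rank mismatch is rejected unless the values attribute holds exactly one element and the result shape has rank 0. The following standalone restatement (a hypothetical helper, for illustration only) spells out the accepted combinations.

#include <cstdint>

// Hypothetical restatement of the ConstShapeOp element-count rule: the number
// of elements must equal the result rank, with a rank-0 shape also allowed to
// carry a single element.
static bool constShapeCountIsValid(int64_t count, int64_t rank) {
  return count == rank || (count == 1 && rank == 0);
}

int main() {
  bool ok = constShapeCountIsValid(3, 3) && constShapeCountIsValid(1, 0) &&
            constShapeCountIsValid(0, 0) && !constShapeCountIsValid(2, 3);
  return ok ? 0 : 1;
}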
4661 
4662 //===----------------------------------------------------------------------===//
4663 // TOSA Attribute Definitions.
4664 //===----------------------------------------------------------------------===//
4665 
4666 #define GET_ATTRDEF_CLASSES
4667 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4668 
4669 //===----------------------------------------------------------------------===//
4670 // TOSA Type Definitions.
4671 //===----------------------------------------------------------------------===//
4672 #define GET_TYPEDEF_CLASSES
4673 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4674 
4675 //===----------------------------------------------------------------------===//
4676 // TOSA Operator Definitions.
4677 //===----------------------------------------------------------------------===//
4678 
4679 #define GET_OP_CLASSES
4680 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"