TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49 
50 namespace {
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52 
53 //===----------------------------------------------------------------------===//
54 // Dialect Function Inliner Interface.
55 //===----------------------------------------------------------------------===//
56 struct TosaInlinerInterface : public DialectInlinerInterface {
57  using DialectInlinerInterface::DialectInlinerInterface;
58 
59  //===--------------------------------------------------------------------===//
60  // Analysis Hooks.
61  //===--------------------------------------------------------------------===//
62 
63  /// All operations can be inlined by default.
64  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return true;
67  }
68 
69  /// All regions with If and While parent operators can be inlined.
70  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71  IRMapping &map) const final {
72  return (isa<tosa::IfOp>(dest->getParentOp()) ||
73  isa<tosa::WhileOp>(dest->getParentOp()));
74  }
75 };
76 
77 /// This class implements the bytecode interface for the Tosa dialect.
78 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79  TosaDialectBytecodeInterface(Dialect *dialect)
80  : BytecodeDialectInterface(dialect) {}
81 
82  //===--------------------------------------------------------------------===//
83  // Attributes
84 
85  Attribute readAttribute(DialectBytecodeReader &reader) const override {
86  return ::readAttribute(getContext(), reader);
87  }
88 
89  LogicalResult writeAttribute(Attribute attr,
90  DialectBytecodeWriter &writer) const override {
91  return ::writeAttribute(attr, writer);
92  }
93 
94  //===--------------------------------------------------------------------===//
95  // Types
96 
97  Type readType(DialectBytecodeReader &reader) const override {
98  return ::readType(getContext(), reader);
99  }
100 
101  LogicalResult writeType(Type type,
102  DialectBytecodeWriter &writer) const override {
103  return ::writeType(type, writer);
104  }
105 
106  void writeVersion(DialectBytecodeWriter &writer) const final {
107  // TODO: Populate.
108  }
109 
110  std::unique_ptr<DialectVersion>
111  readVersion(DialectBytecodeReader &reader) const final {
112  // TODO: Populate
113  reader.emitError("Dialect does not support versioning");
114  return nullptr;
115  }
116 
117  LogicalResult upgradeFromVersion(Operation *topLevelOp,
118  const DialectVersion &version) const final {
119  return success();
120  }
121 };
122 
123 } // namespace
124 
125 //===----------------------------------------------------------------------===//
126 // TOSA control flow support.
127 //===----------------------------------------------------------------------===//
128 
129 /// Returns the while loop body.
130 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131  return {&getBodyGraph()};
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // Tosa dialect initialization.
136 //===----------------------------------------------------------------------===//
137 
138 void TosaDialect::initialize() {
139  addTypes<
140 #define GET_TYPEDEF_LIST
141 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
142  >();
143  addOperations<
144 #define GET_OP_LIST
145 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
146  >();
147  addAttributes<
148 #define GET_ATTRDEF_LIST
149 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
150  >();
151  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
152  declarePromisedInterfaces<
153  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
154  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
155  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
156  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
157  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
158  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
159  GreaterEqualOp, MatMulOp>();
160 }
161 
162 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
163  Type type, Location loc) {
164  // Tosa dialect constants only support ElementsAttr, unlike the standard
165  // dialect constant, which supports all attributes.
166  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
167  return builder.create<tosa::ConstShapeOp>(
168  loc, type, llvm::cast<DenseIntElementsAttr>(value));
169  }
170  if (llvm::isa<ElementsAttr>(value))
171  return builder.create<tosa::ConstOp>(loc, type,
172  llvm::cast<ElementsAttr>(value));
173  return nullptr;
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Parsers and printers
178 //===----------------------------------------------------------------------===//
179 
180 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
181  Attribute &attr) {
182  if (succeeded(parser.parseOptionalEqual())) {
183  if (failed(parser.parseAttribute(attr))) {
184  return parser.emitError(parser.getCurrentLocation())
185  << "expected attribute";
186  }
187  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
188  typeAttr = TypeAttr::get(typedAttr.getType());
189  }
190  return success();
191  }
192 
193  Type type;
194  if (failed(parser.parseColonType(type))) {
195  return parser.emitError(parser.getCurrentLocation()) << "expected type";
196  }
197  typeAttr = TypeAttr::get(type);
198 
199  return success();
200 }
201 
202 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
203  Attribute attr) {
204  bool needsSpace = false;
205  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
206  if (!typedAttr || typedAttr.getType() != type.getValue()) {
207  p << ": ";
208  p.printAttribute(type);
209  needsSpace = true; // subsequent attr value needs a space separator
210  }
211  if (attr) {
212  if (needsSpace)
213  p << ' ';
214  p << "= ";
215  p.printAttribute(attr);
216  }
217 }
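// For example, a hypothetical op using this custom<TypeOrAttr> assembly
// directive would round-trip either of the following forms (illustrative only):
//   ... : tensor<4xf32>               (type only, taken via parseColonType)
//   ... = dense<0.0> : tensor<4xf32>  (typed attribute; the type is derived
//                                      from the attribute)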
218 
219 //===----------------------------------------------------------------------===//
220 // Tosa utilities.
221 //===----------------------------------------------------------------------===//
222 
223 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
224  if (lhs % rhs != 0)
225  return std::nullopt;
226  return lhs / rhs;
227 }
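// For example, idivCheck(12, 4) returns 3, while idivCheck(10, 3) returns
// std::nullopt because 10 is not evenly divisible by 3.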
228 
229 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
230  auto srcType = getElementTypeOrSelf(type);
231  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
232  srcType = quantType.getStorageType();
233  return srcType;
234 }
235 
236 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
237  return getStorageElementTypeOrSelf(value.getType());
238 }
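// For example, for a value of type tensor<2x!quant.uniform<i8:f32, 0.5>> this
// returns the storage type i8, while for tensor<2xf32> it returns f32.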
239 
240 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
241  Value valZp, StringRef name) {
242  Type eType = getStorageElementTypeOrSelf(val.getType());
243  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
244 
245  bool bothInts =
246  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
247  bool sameBitWidth =
248  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
249 
250  if (!bothInts || !sameBitWidth) {
251  return op->emitOpError()
252  << "expected " << name << " and " << name
253  << "_zp to both be integer of the same bitwidth, but got " << eType
254  << " vs. " << eZpType;
255  }
256  return success();
257 }
258 
259 // Create a pad-const constant tensor with value `val` of the required data type.
260 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
261  Value src, int32_t val) {
262  const auto srcType = getElementTypeOrSelf(src);
263  const auto srcElemType = getStorageElementTypeOrSelf(src);
264  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
265  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
266  const auto padConstAttr{
267  llvm::isa<FloatType>(srcElemType)
268  ? DenseElementsAttr::get(padConstEType,
269  builder.getFloatAttr(srcElemType, val))
270  : DenseElementsAttr::get(padConstEType,
271  builder.getIntegerAttr(srcElemType, val))};
272  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
273 }
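// For example, with an f32 source tensor and val == 0 this materializes a
// tosa.const of type tensor<1xf32> holding 0.0. For a quantized source the
// attribute uses the storage type (e.g. i8) while the result keeps the
// quantized element type.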
274 
275 //===----------------------------------------------------------------------===//
276 // TOSA Operator Verifiers.
277 //===----------------------------------------------------------------------===//
278 
279 template <typename T>
280 static LogicalResult verifyConvOp(T op) {
281  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
282  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
283 
284  auto inputEType = inputType.getElementType();
285  auto weightEType = weightType.getElementType();
286  auto biasEType =
287  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
288  auto resultEType =
289  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
290  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
291  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
292 
293  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
294  inputEType = quantType.getStorageType();
295 
296  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
297  weightEType = quantType.getStorageType();
298 
299  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
300  biasEType = quantType.getStorageType();
301 
302  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
303  resultEType = quantType.getStorageType();
304 
305  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
306  // for now, only enforce bias element type == result element type for
307  // float types.
308  op.emitOpError(
309  "expect both bias and result to have same element type, got ")
310  << biasEType << " and " << resultEType;
311  return failure();
312  }
313 
314  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
315  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
316  if (inputEType != weightEType) {
317  op.emitOpError(
318  "expect both input and weight to have same element type, got ")
319  << inputEType << " and " << weightEType;
320  return failure();
321  }
322  }
323 
324  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
325  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
326 
327  // Either both must be float or both non-float.
328  if (inputIsFloat != weightIsFloat) {
329  op.emitOpError(
330  "expect both input and weight to be float or not together, got ")
331  << inputEType << " and " << weightEType;
332  return failure();
333  }
334 
335  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
336  if (inputEType != inputZpEType) {
337  return op.emitOpError("expect both input and its zero point are the same "
338  "element type, got ")
339  << inputEType << " and " << inputZpEType;
340  }
341 
342  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
343  if (weightEType != weightZpEType) {
344  return op.emitOpError("expect both weight and its zero point are the same "
345  "element type, got ")
346  << weightEType << " and " << weightZpEType;
347  }
348 
349  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
350  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
351  return failure();
352 
353  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
354  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
355  return failure();
356 
357  return success();
358 }
359 
360 LogicalResult tosa::ConstOp::verify() {
361 
362  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
363  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
364 
365  if (!attrType || !outputType) {
366  emitOpError("expected tensors for attr/result type");
367  return failure();
368  }
369 
370  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
371  outputType.getElementType())) {
372  if (result.getStorageType() == attrType.getElementType())
373  return success();
374  }
375 
376  if (attrType.getElementType() != outputType.getElementType()) {
377  emitOpError("expected same attr/result element types");
378  return failure();
379  }
380 
381  return success();
382 }
383 
384 template <typename T>
385 static LogicalResult verifyConvOpModes(T op) {
386  auto inputEType =
387  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
388 
389  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
390  inputEType = quantType.getStorageType();
391 
392  auto accType = op.getAccType();
393  if (inputEType.isInteger(8) && !accType.isInteger(32))
394  return op.emitOpError("accumulator type for i8 tensor is not i32");
395 
396  if (inputEType.isInteger(16) && !accType.isInteger(48))
397  return op.emitOpError("accumulator type for i16 tensor is not i48");
398 
399  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
400  return op.emitOpError("accumulator type for f8 tensor is not f16");
401 
402  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
403  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
404 
405  if (inputEType.isBF16() && !accType.isF32())
406  return op.emitOpError("accumulator type for bf16 tensor is not f32");
407 
408  if (inputEType.isF32() && !accType.isF32())
409  return op.emitOpError("accumulator type for f32 tensor is not f32");
410 
411  auto resultEType =
412  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
413 
414  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
415  resultEType = quantType.getStorageType();
416 
417  return success();
418 }
419 
420 //===----------------------------------------------------------------------===//
421 // ERROR_IF functions.
422 // ERROR_IF is a predicate that must set an error if the condition holds.
423 //===----------------------------------------------------------------------===//
424 
425 template <typename T>
426 static LogicalResult verifyConvOpErrorIf(T op) {
427  llvm::ArrayRef<int64_t> padding = op.getPad();
428  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
429  return op.emitOpError("expect all padding values to be >= 0, got ")
430  << padding;
431 
432  llvm::ArrayRef<int64_t> strides = op.getStride();
433  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
434  return op.emitOpError("expect all stride values to be >= 1, got ")
435  << strides;
436 
437  llvm::ArrayRef<int64_t> dilations = op.getDilation();
438  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
439  return op.emitOpError("expect all dilation values to be >= 1, got ")
440  << dilations;
441 
442  const RankedTensorType outputType =
443  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
444  if (!outputType)
445  // Skip following checks if output is not ranked
446  return success();
447 
448  const RankedTensorType inputType =
449  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
450  const RankedTensorType weightType =
451  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
452 
453  if (inputType && weightType) {
454  const auto verifyOutputSize =
455  [&op](const int64_t inputSize, const int64_t kernelSize,
456  const int64_t outputSize, const int64_t padBefore,
457  const int64_t padAfter, const int64_t stride,
458  const int64_t dilation, const llvm::StringRef dimName,
459  const llvm::StringRef dimAxis,
460  const llvm::StringRef padBeforeName,
461  const llvm::StringRef padAfterName) -> LogicalResult {
462  if (inputSize == ShapedType::kDynamic ||
463  kernelSize == ShapedType::kDynamic)
464  return success();
465 
466  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
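 // For example, with I = 16, K = 3, pa = pb = 1, d = 1 and s = 2 the
 // numerator is 16 - 1 + 1 + 1 - (3 - 1) * 1 = 15, which is not evenly
 // divisible by 2, so the ERROR_IF below fires; with s = 1 it yields 15 and
 // the output size O must then be 16.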
467 
468  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
469  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
470  stride);
471  if (!calculatedOutSizeMinusOne.has_value())
472  return op.emitOpError("expected input_")
473  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
474  << padAfterName << " - (kernel_" << dimName
475  << " - 1) * dilation_" << dimAxis
476  << " to be wholly divisible by stride_" << dimAxis << ", got ("
477  << inputSize << " - 1 + " << padBefore << " + " << padAfter
478  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
479  << stride;
480 
481  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
482  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
483  return op.emitOpError("calculated output ")
484  << dimName << " did not match expected: "
485  << "calculated=" << calculatedOutSize
486  << ", expected=" << outputSize;
487 
488  return success();
489  };
490 
491  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
492  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
493  if (failed(verifyOutputSize(
494  inputType.getDimSize(1), weightType.getDimSize(1),
495  outputType.getDimSize(1), padding[0], padding[1], strides[0],
496  dilations[0], "height", "y", "top", "bottom")))
497  return failure();
498 
499  if (failed(verifyOutputSize(
500  inputType.getDimSize(2), weightType.getDimSize(2),
501  outputType.getDimSize(2), padding[2], padding[3], strides[1],
502  dilations[1], "width", "x", "left", "right")))
503  return failure();
504  }
505 
506  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
507  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
508  if (failed(verifyOutputSize(
509  inputType.getDimSize(1), weightType.getDimSize(0),
510  outputType.getDimSize(1), padding[0], padding[1], strides[0],
511  dilations[0], "height", "y", "top", "bottom")))
512  return failure();
513 
514  if (failed(verifyOutputSize(
515  inputType.getDimSize(2), weightType.getDimSize(1),
516  outputType.getDimSize(2), padding[2], padding[3], strides[1],
517  dilations[1], "width", "x", "left", "right")))
518  return failure();
519  }
520 
521  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
522  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
523  if (failed(verifyOutputSize(
524  inputType.getDimSize(1), weightType.getDimSize(1),
525  outputType.getDimSize(1), padding[0], padding[1], strides[0],
526  dilations[0], "depth", "d", "front", "back")))
527  return failure();
528 
529  if (failed(verifyOutputSize(
530  inputType.getDimSize(2), weightType.getDimSize(2),
531  outputType.getDimSize(2), padding[2], padding[3], strides[1],
532  dilations[1], "height", "y", "top", "bottom")))
533  return failure();
534 
535  if (failed(verifyOutputSize(
536  inputType.getDimSize(3), weightType.getDimSize(3),
537  outputType.getDimSize(3), padding[4], padding[5], strides[2],
538  dilations[2], "width", "x", "left", "right")))
539  return failure();
540  }
541  }
542 
543  const RankedTensorType biasType =
544  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
545  if (!biasType)
546  // Skip following checks if bias is not ranked
547  return success();
548 
549  const int64_t biasChannels = biasType.getDimSize(0);
550  const int64_t outputChannels =
551  outputType.getDimSize(outputType.getRank() - 1);
552  if (biasChannels == ShapedType::kDynamic ||
553  outputChannels == ShapedType::kDynamic)
554  // Skip following checks if biasChannels or outputChannels is dynamic dim
555  return success();
556 
557  if (biasChannels != outputChannels && biasChannels != 1)
558  return op.emitOpError(
559  "bias channels expected to be equal to output channels (")
560  << outputChannels << ") or 1, got " << biasChannels;
561 
562  return success();
563 }
564 
565 // Verify that the two given types have the same element type and shape.
566 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
567  StringRef name1, Type type2,
568  StringRef name2) {
569  auto shapeType1 = dyn_cast<ShapedType>(type1);
570  auto shapeType2 = dyn_cast<ShapedType>(type2);
571  if (!shapeType1 || !shapeType2)
572  return failure();
573 
574  auto elemType1 = shapeType1.getElementType();
575  auto elemType2 = shapeType2.getElementType();
576  if (elemType1 != elemType2)
577  return op->emitOpError()
578  << "require same element type for " << name1 << " (" << elemType1
579  << ") and " << name2 << " (" << elemType2 << ")";
580 
581  if (failed(verifyCompatibleShape(type1, type2)))
582  return op->emitOpError()
583  << "require same shapes for " << name1 << " (" << type1 << ") and "
584  << name2 << " (" << type2 << ")";
585 
586  return success();
587 }
588 
589 // Verify that the two given tensor lists have the same length, element types,
590 // and shapes.
590 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
591  StringRef name1,
592  ValueRange list2,
593  StringRef name2) {
594  if (list1.size() != list2.size())
595  return op->emitOpError()
596  << "require same number of values in " << name1 << " ("
597  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
598 
599  for (auto [type1, type2] :
600  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
601  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
602  return failure();
603  }
604 
605  return success();
606 }
607 
608 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
609  ShapeAdaptor shapeAdaptor(type);
610  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
611  return success();
612 
613  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
614 }
615 
616 // Returns the first declaration point prior to this operation or failure if
617 // not found.
618 static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
619  StringRef symName) {
620  ModuleOp module = op->getParentOfType<ModuleOp>();
621  tosa::VariableOp varOp = nullptr;
622 
623  // TODO: Adopt the SymbolTable trait for Variable ops.
624  // Currently, the variable's definition point is searched via walk(), starting
625  // from the top-level ModuleOp and stopping at the point of use. Once the TOSA
626  // control flow and variable extensions are complete, this may leverage MLIR's
627  // SymbolTable functionality to look up the symbol and enhance the search to a
628  // TOSA-specific graph traversal over the IR structure.
629  module.walk([&](Operation *tempOp) {
630  // Reach this op itself.
631  if (tempOp == op) {
632  return WalkResult::interrupt();
633  }
634 
635  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
636  if (symName == tosaOp.getName()) {
637  varOp = tosaOp;
638  return WalkResult::interrupt();
639  }
640  }
641 
642  return WalkResult::advance();
643  });
644 
645  if (varOp)
646  return varOp;
647 
648  return failure();
649 }
650 
651 template <typename T>
652 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
653  StringRef symName = op.getName();
654  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
655  if (failed(varOp))
656  return op->emitOpError("'")
657  << symName << "' has not been declared by 'tosa.variable'";
658 
659  // Verify type and shape
660  Type varType = cast<tosa::VariableOp>(varOp.value()).getType();
661  if (errorIfTypeOrShapeMismatch(op, type, name, varType, "the input tensor")
662  .failed())
663  return failure();
664 
665  return success();
666 }
667 
668 // Verify that inType and outType have the same element types.
669 template <typename T>
670 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
671  auto inputType = llvm::dyn_cast<TensorType>(inType);
672  auto outputType = llvm::dyn_cast<TensorType>(outType);
673  if (!inputType) {
674  op.emitOpError("expect shaped tensor for input, got ") << inType;
675  return failure();
676  }
677  if (!outputType) {
678  op.emitOpError("expect shaped tensor for output, got ") << outType;
679  return failure();
680  }
681  auto inputElementType = inputType.getElementType();
682  auto outputElementType = outputType.getElementType();
683  auto inputQuantType =
684  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
685  auto outputQuantType =
686  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
687  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
688  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
689  inputElementType != outputElementType) {
690  // Only check if both element types are int/index/float/UniformQuantized,
691  // since it is unclear how to check quant::QuantizedType. This case arises
692  // in test_conv2d_q_grouped_convolution in
693  // tfl-to-tosa-pipeline.mlir.
694  op.emitOpError("expect input and output to have same element type, got ")
695  << inputElementType << " and " << outputElementType;
696  return failure();
697  }
698  return success();
699 }
700 
701 LogicalResult tosa::ArgMaxOp::verify() {
702  const ShapedType resultType = llvm::cast<ShapedType>(getType());
703 
704  // Ensure the output element type is an integer type
705  if (const auto resultETy = resultType.getElementType();
706  !resultETy.isIntOrIndex())
707  return emitOpError("result tensor is not of integer type");
708 
709  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
710  if (!inputType.hasRank())
711  return success();
712 
713  // Ensure axis is within the tensor rank
714  const int64_t axis = getAxisAttr().getInt();
715  if (((axis < 0) || axis >= inputType.getRank()))
716  return emitOpError("specified axis is outside the rank of the tensor");
717 
718  if (!resultType.hasRank())
719  return success();
720 
721  const ArrayRef<int64_t> inputShape = inputType.getShape();
722  const ArrayRef<int64_t> outputShape = resultType.getShape();
723  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
724  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
725  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
726  return emitOpError("expected output shape '")
727  << expectedOutputShape << "', got '" << outputShape << "'";
728 
729  return success();
730 }
731 
732 template <typename T>
733 static LogicalResult verifyPoolingOp(T op) {
734  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
735  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
736  return op.emitOpError("expect all kernel values to be >= 1, got ")
737  << kernel;
738 
739  const llvm::ArrayRef<int64_t> strides = op.getStride();
740  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
741  return op.emitOpError("expect all stride values to be >= 1, got ")
742  << strides;
743 
744  const llvm::ArrayRef<int64_t> padding = op.getPad();
745  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
746  return op.emitOpError("expect all padding values to be >= 0, got ")
747  << padding;
748 
749  // Padding must be less than kernel size to avoid a divide-by-zero
750  const int64_t kernelX = kernel[1];
751  const int64_t padLeft = padding[2];
752  const int64_t padRight = padding[3];
753  if (padRight >= kernelX || padLeft >= kernelX)
754  return op.emitOpError("expected left/right padding to be less than the "
755  "width of the kernel, got pad_left=")
756  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
757 
758  const int64_t kernelY = kernel[0];
759  const int64_t padTop = padding[0];
760  const int64_t padBottom = padding[1];
761  if (padTop >= kernelY || padBottom >= kernelY)
762  return op.emitOpError("expected top/bottom padding to be less than the "
763  "height of the kernel, got pad_top=")
764  << padTop << ", pad_bottom=" << padBottom
765  << ", kernel_y=" << kernelY;
766 
767  const auto inputType =
768  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
769  const auto outputType =
770  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
771  if (!inputType || !outputType)
772  return success();
773 
774  const auto verifyOutputSize =
775  [&op](const int64_t inputSize, const int64_t outputSize,
776  const int64_t kernelSize, const int64_t strideSize,
777  const int64_t padBefore, const int64_t padAfter,
778  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
779  const llvm::StringRef padBeforeName,
780  const llvm::StringRef padAfterName) -> LogicalResult {
781  if (ShapedType::isDynamic(inputSize))
782  return success();
783 
784  const std::optional<int64_t> calculatedOutSizeMinusOne =
785  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
786  if (!calculatedOutSizeMinusOne.has_value())
787  return op.emitOpError("expected input_")
788  << dimName << " + pad_" << padBeforeName << " + pad_"
789  << padAfterName << " - kernel_" << dimAxis
790  << " to be wholly divisible by stride_" << dimAxis << ", got ("
791  << inputSize << " + " << padBefore << " + " << padAfter << " - "
792  << kernelSize << ") / " << strideSize;
793 
794  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
795  if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
796  return op.emitOpError("calculated output ")
797  << dimName << " did not match expected: "
798  << "calculated=" << calculatedOutSize
799  << ", expected=" << outputSize;
800 
801  return success();
802  };
803 
804  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
805  kernel[0], strides[0], padding[0], padding[1],
806  "height", "y", "top", "bottom")))
807  return failure();
808 
809  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
810  kernel[1], strides[1], padding[2], padding[3],
811  "width", "x", "left", "right")))
812  return failure();
813 
814  return success();
815 }
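// For example, a 1x32x32x8 input pooled with kernel = [2, 2], stride = [2, 2]
// and pad = [0, 0, 0, 0] yields (32 + 0 + 0 - 2) / 2 + 1 = 16 along each
// spatial dimension, so the expected output shape is 1x16x16x8.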
816 
817 LogicalResult tosa::AvgPool2dOp::verify() {
818  if (failed(verifyPoolingOp(*this)))
819  return failure();
820 
821  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
822  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
823  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
824  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
825 
826  auto accType = getAccType();
827  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
828  return emitOpError("accumulator type for integer tensor is not i32");
829 
830  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
831  return emitOpError("accumulator type for f16 tensor is not f16/f32");
832 
833  if (inputETy.isBF16() && !accType.isF32())
834  return emitOpError("accumulator type for bf16 tensor is not f32");
835 
836  if (inputETy.isF32() && !accType.isF32())
837  return emitOpError("accumulator type for f32 tensor is not f32");
838 
839  if (inputETy != inputZpETy)
840  return emitOpError("expect both input and its zero point are the same "
841  "element type, got ")
842  << inputETy << " and " << inputZpETy;
843 
844  if (resultETy != outputZpETy)
845  return emitOpError("expect both output and its zero point are the same "
846  "element type, got ")
847  << resultETy << " and " << outputZpETy;
848 
849  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
850  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
851  return failure();
852 
853  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
854  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
855  return failure();
856 
857  return success();
858 }
859 
860 LogicalResult tosa::ClampOp::verify() {
861  mlir::Type inputETy =
862  llvm::cast<ShapedType>(getInput().getType()).getElementType();
863  if (auto quantType =
864  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
865  inputETy = quantType.getStorageType();
866  }
867  mlir::Type outputETy =
868  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
869  if (auto quantType =
870  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
871  outputETy = quantType.getStorageType();
872  }
873  if (inputETy != outputETy)
874  return emitOpError("input/output element types are incompatible.");
875 
876  auto maxValAttr = getMaxValAttr();
877  auto minValAttr = getMinValAttr();
878 
879  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
880 
881  if (inputETy.isInteger(dataTypeBitWidth)) {
882  // if input datatype is integer, check that the min_val/max_val attributes
883  // are integer attributes, and that their type is the same as the input's
884  // datatype
885  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
886  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
887  if (!intMaxValAttr || !intMinValAttr ||
888  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
889  (intMaxValAttr.getType() != inputETy))
890  return emitOpError("min/max attributes types are incompatible with "
891  "input/output element types.");
892 
893  const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
894  const APInt minVal = intMinValAttr.getValue();
895  const APInt maxVal = intMaxValAttr.getValue();
896  if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
897  return emitOpError("expected min_val <= max_val, got min_val=")
898  << minValAttr << ", max_val=" << maxValAttr;
899  } else {
900  // otherwise, input datatype is float, check that the min_val/max_val
901  // attributes share the same type and that their type is the same as the
902  // input's datatype
903  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
904  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
905  if (!floatMaxValAttr || !floatMinValAttr ||
906  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
907  (floatMaxValAttr.getType() != inputETy))
908  return emitOpError("min/max attributes types are incompatible with "
909  "input/output element types.");
910 
911  const APFloat minVal = floatMinValAttr.getValue();
912  const APFloat maxVal = floatMaxValAttr.getValue();
913  if (minVal.isNaN() || maxVal.isNaN())
914  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
915  << minValAttr << ", max_val=" << maxValAttr;
916 
917  if (maxVal < minVal)
918  return emitOpError("expected min_val <= max_val, got min_val=")
919  << minValAttr << ", max_val=" << maxValAttr;
920  }
921 
922  return success();
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // TOSA Operator Quantization Builders.
927 //===----------------------------------------------------------------------===//
928 
929 /// This builder is called on all convolution operators except TransposeConv,
930 /// which has specialized output shape semantics. The builder also defines the
931 /// bit width of the output given the bit widths of the input and weight content.
932 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
933  Type outputType, Value input, Value weight,
934  Value bias, DenseI64ArrayAttr pad,
935  DenseI64ArrayAttr stride,
936  DenseI64ArrayAttr dilation,
937  TypeAttr accType) {
938  auto zps = createZPsAsConst(builder, input, weight);
939  result.addOperands({input, weight, bias, zps.first, zps.second});
940  result.addAttribute("pad", pad);
941  result.addAttribute("stride", stride);
942  result.addAttribute("dilation", dilation);
943  result.addAttribute("acc_type", accType);
944  Type finalOutputType = outputType;
945  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
946  if (quantAttr) {
947  finalOutputType =
948  buildConvOpResultTypeInfo(builder, outputType, input, weight);
949  }
950  result.addTypes(finalOutputType);
951 }
952 
953 /// Handles tosa.transpose_conv2d which has outpad and output shape
954 /// attributes.
955 static void
956 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
957  Type outputType, Value input, Value weight,
958  Value bias, DenseI64ArrayAttr outpad,
959  DenseI64ArrayAttr stride, TypeAttr accType) {
960  auto zps = createZPsAsConst(builder, input, weight);
961  result.addOperands({input, weight, bias, zps.first, zps.second});
962  result.addAttribute("out_pad", outpad);
963  result.addAttribute("stride", stride);
964  result.addAttribute("acc_type", accType);
965  Type finalOutputType = outputType;
966  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
967  if (quantAttr) {
968  finalOutputType =
969  buildConvOpResultTypeInfo(builder, outputType, input, weight);
970  }
971  result.addTypes(finalOutputType);
972 }
973 
974 /// The tosa.matmul op is also intended to be generated where a fully_connected
975 /// op would be constructed but the weight is not a constant. In this case,
976 /// the fully_connected op must be expressed using matmul.
977 /// TODO: Add a link to the legalization document explaining this.
978 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
979  OperationState &result, Type outputType,
980  Value a, Value b) {
981  auto zps = createZPsAsConst(builder, a, b);
982  result.addOperands({a, b, zps.first, zps.second});
983 
984  Type finalOutputType{outputType};
985  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
986  auto eType = getStorageElementTypeOrSelf(a.getType());
987  auto inputBits = eType.getIntOrFloatBitWidth();
988 
989  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
990  assert(outputShapedType && "Output must be a shaped type");
991 
992  IntegerType accElementType;
993  if (inputBits == 16)
994  accElementType = builder.getIntegerType(48);
995  else
996  accElementType = builder.getI32Type();
997 
998  finalOutputType = outputShapedType.clone(accElementType);
999  }
1000  result.addTypes(finalOutputType);
1001 }
1002 
1003 /// Both the tosa.avg_pool2d and unary ops use the same
1004 /// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as it
1005 /// has additional parameters not part of the unary ops.
1006 static void
1007 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1008  Type outputType, Value input,
1009  DenseArrayAttr kernel, DenseArrayAttr stride,
1010  DenseArrayAttr pad, TypeAttr accType) {
1011  const Location loc{result.location};
1012  int64_t inputZp{0};
1013  int64_t outputZp{0};
1014 
1015  if (auto quantAttr =
1016  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1017  inputZp = quantAttr.getInputZp();
1018  outputZp = quantAttr.getOutputZp();
1019  }
1020  const std::optional<Value> inputZpOp =
1021  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1022  if (!inputZpOp) {
1023  (void)emitError(
1024  loc,
1025  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1026  }
1027  const std::optional<Value> outputZpOp =
1028  createZeroPointTensor(builder, loc, outputType, outputZp);
1029  if (!outputZpOp) {
1030  (void)emitError(loc, "Failed to create output zero point tensor for "
1031  "quantized AVG_POOL2D op");
1032  }
1033 
1034  if (inputZpOp && outputZpOp) {
1035  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1036  } else {
1037  // Failed to create one or more zero points above: just add the input as an
1038  // operand. This will trigger an error when building the op because of the
1039  // missing zero points.
1040  result.addOperands({input});
1041  }
1042  result.addAttribute("kernel", kernel);
1043  result.addAttribute("stride", stride);
1044  result.addAttribute("pad", pad);
1045  result.addAttribute("acc_type", accType);
1046  result.types.push_back(outputType);
1047 }
1048 
1049 /// This builder is called on the single-parameter negate operator
1050 /// to construct input and output zero points based on their
1051 /// types.
1052 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1053  OperationState &result, Type outputType,
1054  Value input) {
1055  const Location loc{result.location};
1056  int64_t input1Zp{0};
1057  int64_t outputZp{0};
1058  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1059  if (quantAttr) {
1060  input1Zp = quantAttr.getInputZp();
1061  outputZp = quantAttr.getOutputZp();
1062  }
1063  const std::optional<Value> input1ZpOp =
1064  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1065  if (!input1ZpOp) {
1066  (void)emitError(
1067  loc, "Failed to create input1 zero point for quantized NEGATE op");
1068  }
1069 
1070  const std::optional<Value> outputZpOp =
1071  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1072  if (!outputZpOp) {
1073  (void)emitError(
1074  loc, "Failed to create output zero point for quantized NEGATE op");
1075  }
1076 
1077  if (input1ZpOp && outputZpOp) {
1078  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1079  } else {
1080  // Failed to create one or more zero points above: just add the input as an
1081  // operand. This will trigger an error when building the op because of the
1082  // missing zero points.
1083  result.addOperands({input});
1084  }
1085 
1086  result.types.push_back(outputType);
1087 }
1088 
1089 /// This builder is called on the TOSA pad operator, which needs to create its
1090 /// own OptionalAttr quantization_attr parameter to scale the padding values
1091 /// correctly. If no pad_const is provided, it is interpreted as zero padding.
1092 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1093  Type outputType, Value input,
1094  Value paddings) {
1095  const Location loc{result.location};
1096  int32_t zp{0};
1097  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1098  if (quantAttr) {
1099  zp = static_cast<int32_t>(quantAttr.getInputZp());
1100  }
1101  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1102  result.addOperands({input, paddings, padConstOp});
1103  result.types.push_back(outputType);
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // TOSA Operator Return Type Inference.
1108 //===----------------------------------------------------------------------===//
1109 
1110 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1111  SmallVector<int64_t> &outShape) {
1112  int64_t outRank = 0;
1113  for (int i = 0, e = operands.size(); i != e; ++i) {
1114  auto shape = operands.getShape(i);
1115  if (!shape.hasRank()) {
1116  // TODO(jennik): Update function to have better case handling for
1117  // invalid operands and for ranked tensors.
1118  return failure();
1119  }
1120  outRank = std::max<int64_t>(outRank, shape.getRank());
1121  }
1122 
1123  outShape.resize(outRank, 1);
1124 
1125  for (int i = 0, e = operands.size(); i != e; ++i) {
1126  auto shape = operands.getShape(i);
1127  auto rankDiff = outShape.size() - shape.getRank();
1128 
1129  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1130  auto dim1 = outShape[i + rankDiff];
1131  auto dim2 = shape.getDimSize(i);
1132  auto resolvedDim = dim1;
1133 
1134  if (dim1 == 1) {
1135  resolvedDim = dim2;
1136  } else if (dim2 == 1) {
1137  resolvedDim = dim1;
1138  } else if (dim1 != dim2) {
1139  return failure();
1140  }
1141  outShape[i + rankDiff] = resolvedDim;
1142  }
1143  }
1144 
1145  return success();
1146 }
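// For example, operand shapes [2, 1, 3] and [4, 1] broadcast to [2, 4, 3],
// while shapes [2, 3] and [4, 3] fail because 2 and 4 are distinct non-unit
// sizes on the same dimension.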
1147 
1148 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1149  MLIRContext *context, ::std::optional<Location> location,
1150  ArgMaxOp::Adaptor adaptor,
1151  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1152  ShapeAdaptor inputShape(adaptor.getInput().getType());
1153  IntegerAttr axis = adaptor.getProperties().axis;
1154  int32_t axisVal = axis.getValue().getSExtValue();
1155 
1156  if (!inputShape.hasRank()) {
1157  inferredReturnShapes.push_back(ShapedTypeComponents());
1158  return success();
1159  }
1160 
1161  SmallVector<int64_t> outShape;
1162  outShape.reserve(inputShape.getRank() - 1);
1163  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1164  if (i == axisVal)
1165  continue;
1166  outShape.push_back(inputShape.getDimSize(i));
1167  }
1168 
1169  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1170  return success();
1171 }
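// For example, an input of shape [2, 3, 4] with axis = 1 infers an output
// shape of [2, 4]: the reduced axis is dropped from the input shape.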
1172 
1173 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1174  MLIRContext *context, ::std::optional<Location> location,
1175  RFFT2dOp::Adaptor adaptor,
1176  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1177  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1178 
1179  if (!inputShape.hasRank())
1180  return failure();
1181 
1182  llvm::SmallVector<int64_t> outputShape;
1183  outputShape.resize(3, ShapedType::kDynamic);
1184  outputShape[0] = inputShape.getDimSize(0);
1185  outputShape[1] = inputShape.getDimSize(1);
1186  int64_t inWidth = inputShape.getDimSize(2);
1187 
1188  // Note that we can support this calculation symbolically
1189  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1190  if (inWidth != ShapedType::kDynamic)
1191  outputShape[2] = inWidth / 2 + 1;
1192 
1193  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1194  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1195 
1196  return success();
1197 }
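// For example, a real input of shape [1, 8, 8] infers two outputs (the real
// and imaginary parts) of shape [1, 8, 5], since 8 / 2 + 1 = 5.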
1198 
1199 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1200  const llvm::StringRef dimName) {
1201  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1202  if (!isPowerOfTwo)
1203  return op->emitOpError("expected ")
1204  << dimName << " to be a power of two, got " << dimSize;
1205 
1206  return success();
1207 }
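// For example, dimension sizes 1, 2, 8 and 64 pass this check, while 0, 6 and
// 12 are rejected because they are not positive powers of two.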
1208 
1209 LogicalResult tosa::RFFT2dOp::verify() {
1210  const auto outputTypes = getResultTypes();
1211  if (failed(verifyCompatibleShapes(outputTypes)))
1212  return emitOpError("expected output shapes to match, got ") << outputTypes;
1213 
1214  const auto inputType =
1215  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1216  if (!inputType)
1217  return success();
1218 
1219  const int64_t height = inputType.getDimSize(1);
1220  if (!ShapedType::isDynamic(height) &&
1221  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1222  return failure();
1223 
1224  const int64_t width = inputType.getDimSize(2);
1225  if (!ShapedType::isDynamic(width) &&
1226  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1227  return failure();
1228 
1229  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1230  if (!outputType)
1231  return success();
1232 
1233  // Batch and height input/output dimensions should match
1234  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1235  outputType.getShape().drop_back())))
1236  return emitOpError("expected batch and height dimensions of input/output "
1237  "to match, got input=")
1238  << inputType << " output=" << outputType;
1239 
1240  // Output width dimension expected to be input_width / 2 + 1
1241  const int64_t outputWidth = outputType.getDimSize(2);
1242  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1243  (outputWidth != (width / 2) + 1))
1244  return emitOpError(
1245  "expected output width to be equal to input_width / 2 + 1, got ")
1246  << outputWidth;
1247 
1248  return success();
1249 }
1250 
1251 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1252  MLIRContext *context, ::std::optional<Location> location,
1253  FFT2dOp::Adaptor adaptor,
1254  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1255  inferredReturnShapes.push_back(
1256  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1257  inferredReturnShapes.push_back(
1258  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1259  return success();
1260 }
1261 
1262 LogicalResult tosa::FFT2dOp::verify() {
1263  const auto inputRealType =
1264  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1265  const auto inputImagType =
1266  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1267  if (!inputRealType || !inputImagType)
1268  return success();
1269 
1270  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1271  return ShapedType::isDynamic(a) ? a : b;
1272  };
1273 
1274  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1275  inputImagType.getDimSize(1));
1276  if (!ShapedType::isDynamic(height) &&
1277  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1278  return failure();
1279 
1280  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1281  inputImagType.getDimSize(2));
1282  if (!ShapedType::isDynamic(width) &&
1283  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1284  return failure();
1285 
1286  return success();
1287 }
1288 
1289 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1290  MLIRContext *context, ::std::optional<Location> location,
1291  ConcatOp::Adaptor adaptor,
1292  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1293  // Infer all dimension sizes by reducing based on inputs.
1294  const Properties &prop = adaptor.getProperties();
1295  int32_t axis = prop.axis.getValue().getSExtValue();
1296  llvm::SmallVector<int64_t> outputShape;
1297  bool hasRankedInput = false;
1298  for (auto operand : adaptor.getOperands()) {
1299  ShapeAdaptor operandShape(operand.getType());
1300  if (!operandShape.hasRank())
1301  continue;
1302 
1303  // Copy the Operand's rank.
1304  if (!hasRankedInput)
1305  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1306 
1307  // Copy shapes until the dim is non-dynamic.
1308  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1309  if (i == axis || operandShape.isDynamicDim(i))
1310  continue;
1311  if (outputShape[i] == ShapedType::kDynamic)
1312  outputShape[i] = operandShape.getDimSize(i);
1313  if (outputShape[i] != operandShape.getDimSize(i))
1314  return emitOptionalError(location,
1315  "Cannot concat tensors with different sizes"
1316  " on the non-axis dimension ",
1317  i);
1318  }
1319 
1320  hasRankedInput = true;
1321  }
1322 
1323  if (adaptor.getInput1().empty())
1324  return failure();
1325 
1326  Type inputType =
1327  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1328  if (!hasRankedInput) {
1329  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1330  return success();
1331  }
1332 
1333  // Determine the dimension size along the concatenation axis.
1334  int64_t concatDimSize = 0;
1335  for (auto operand : adaptor.getOperands()) {
1336  ShapeAdaptor operandShape(operand.getType());
1337 
1338  // We need to know the length of the concatenation axis of all inputs to
1339  // determine the dimension size of the output shape.
1340  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1341  concatDimSize = ShapedType::kDynamic;
1342  break;
1343  }
1344 
1345  concatDimSize += operandShape.getDimSize(axis);
1346  }
1347 
1348  outputShape[axis] = concatDimSize;
1349 
1350  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1351  return success();
1352 }
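// For example, concatenating inputs of shape [2, 3] and [2, 5] along axis 1
// infers an output shape of [2, 8]; if any input's axis dimension is dynamic,
// the inferred axis dimension is dynamic as well.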
1353 
1354 LogicalResult tosa::ConcatOp::verify() {
1355  // Check that each input has the same element type as the output.
1356  auto outType = getOutput().getType();
1357  const Operation::operand_range inputList = getInput1();
1358 
1359  // Check there is at least one input
1360  if (inputList.empty())
1361  return emitOpError("expect at least one input");
1362 
1363  if (!llvm::all_of(inputList, [&](auto input) {
1364  return succeeded(verifySameElementTypes(
1365  *this, /* inType = */ input.getType(), outType));
1366  })) {
1367  return failure();
1368  }
1369 
1370  const int32_t axis = getAxis();
1371  ShapeAdaptor firstRankedInputShape = nullptr;
1372  for (const auto &input : inputList) {
1373  const Type inputType = input.getType();
1374  ShapeAdaptor currShape(inputType);
1375  if (currShape.hasRank()) {
1376  firstRankedInputShape = currShape;
1377  // Check axis is in expected range
1378  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1379  return emitOpError("expect axis to be within range 0 <= axis < "
1380  "rank(input1[firstRankedTensorIdx]), got ")
1381  << axis;
1382  break;
1383  }
1384  }
1385 
1386  const auto allOperandsHasRank = [](const Value input) {
1387  return ShapeAdaptor(input.getType()).hasRank();
1388  };
1389  if (llvm::all_of(inputList, allOperandsHasRank)) {
1390  const int64_t firstInputRank = firstRankedInputShape.getRank();
1391 
1392  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1393  const ShapeAdaptor inputShape(input.getType());
1394  const int64_t inputRank = inputShape.getRank();
1395  const size_t operandNum = index + 1;
1396 
1397  // Check that each operand has the same rank
1398  if (inputRank != firstInputRank)
1399  return emitOpError(
1400  "expect all operands to have the same rank, but got ")
1401  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1402  << operandNum;
1403 
1404  // Check non-axis dims match
1405  for (int i = 0; i < inputRank; i++) {
1406  const int64_t inputDim = inputShape.getDimSize(i);
1407  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1408  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1409  inputShape.isDynamicDim(i))
1410  continue;
1411  if (inputDim != firstInputDim)
1412  return emitOpError("expect all operand shapes to have the same sizes "
1413  "on non-axis dimensions, but got ")
1414  << inputDim << " vs " << firstInputDim << " at index " << i
1415  << " on operands 0 and " << operandNum;
1416  }
1417  }
1418 
1419  // ERROR_IF(axis_sum != shape[axis]);
1420  int64_t axisSum = 0;
1421  for (const auto &input : inputList) {
1422  const ShapeAdaptor inputShape(input.getType());
1423  if (inputShape.isDynamicDim(axis)) {
1424  // make axisSum negative to indicate invalid value
1425  axisSum = -1;
1426  break;
1427  }
1428  axisSum += inputShape.getDimSize(axis);
1429  }
1430  const ShapeAdaptor outputShape(outType);
1431  if (axisSum >= 0 && outputShape.hasRank() &&
1432  !outputShape.isDynamicDim(axis) &&
1433  axisSum != outputShape.getDimSize(axis))
1434  return emitOpError("requires sum of axis dimensions of input1 "
1435  "equal to output axis dimension, got ")
1436  << axisSum << " and " << outputShape.getDimSize(axis);
1437  }
1438 
1439  return success();
1440 }
1441 
1442 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1443  MLIRContext *context, ::std::optional<Location> location,
1444  ValueShapeRange operands, DictionaryAttr attributes,
1445  OpaqueProperties properties, RegionRange regions,
1446  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1447  auto elementType = IntegerType::get(context, /*width=*/1);
1448 
1449  llvm::SmallVector<int64_t> outShape;
1450  if (resolveBroadcastShape(operands, outShape).failed()) {
1451  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1452  return success();
1453  }
1454 
1455  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1456  return success();
1457 }
1458 
1459 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1460  if (l.size() != r.size() || l.size() != 1)
1461  return false;
1462  return succeeded(verifyCompatibleShape(l[0], r[0]));
1463 }
1464 
1465 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1466  MLIRContext *context, ::std::optional<Location> location,
1467  MatMulOp::Adaptor adaptor,
1468  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1469  ShapeAdaptor lhsShape(adaptor.getA().getType());
1470  ShapeAdaptor rhsShape(adaptor.getB().getType());
1471 
1472  // All shapes are dynamic.
1473  SmallVector<int64_t> outShape;
1474  outShape.resize(3, ShapedType::kDynamic);
1475 
1476  if (lhsShape.hasRank()) {
1477  outShape[0] = lhsShape.getDimSize(0);
1478  outShape[1] = lhsShape.getDimSize(1);
1479  }
1480 
1481  if (rhsShape.hasRank()) {
1482  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1483  : outShape[0];
1484  outShape[2] = rhsShape.getDimSize(2);
1485  }
1486 
1487  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1488  return success();
1489 }
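// For example, with a of type tensor<2x3x4xf32> and b of type tensor<2x4x5xf32>
// the inferred output shape is 2x3x5: the first two dimensions come from a and
// the last from b; if a is unranked, the batch dimension falls back to b's.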
1490 
1491 LogicalResult MatMulOp::verify() {
1492  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1493  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1494 
1495  // Must be shaped tensor types
1496  if (!aType)
1497  return emitOpError("expect a shaped tensor for input a, got ")
1498  << getA().getType();
1499 
1500  if (!bType)
1501  return emitOpError("expect a shaped tensor for input b, got ")
1502  << getB().getType();
1503 
1504  auto aElementType = aType.getElementType();
1505  auto bElementType = bType.getElementType();
1506 
1507  auto aQuantizedEType =
1508  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1509  auto bQuantizedEType =
1510  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1511 
1512  if (aQuantizedEType || bQuantizedEType) {
1513  if (!aQuantizedEType || !bQuantizedEType) {
1514  return emitOpError("expect operands to be both quantized or both not "
1515  "quantized, got ")
1516  << aElementType << " and " << bElementType;
1517  }
1518  // both a and b have quantized element types
1519  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1520  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1521  if (aQuantWidth != bQuantWidth) {
1522  return emitOpError("expect quantized operands to have same widths, got ")
1523  << aQuantWidth << " and " << bQuantWidth;
1524  }
1525  } else {
1526  // non-quantized element types
1527  if (aElementType != bElementType) {
1528  return emitOpError("expect same element type for inputs a and b, got ")
1529  << aElementType << " and " << bElementType;
1530  }
1531  }
1532 
1533  // check a_zp and b_zp
1534  auto aEType = getStorageElementTypeOrSelf(aType);
1535  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1536  if (aEType != aZpEType) {
1537  return emitOpError("expect input a and a_zp have the same "
1538  "element type, got ")
1539  << aEType << " and " << aZpEType;
1540  }
1541 
1542  auto bEType = getStorageElementTypeOrSelf(bType);
1543  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1544  if (bEType != bZpEType) {
1545  return emitOpError("expect input b and b_zp have the same "
1546  "element type, got ")
1547  << bEType << " and " << bZpEType;
1548  }
1549 
1550  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1551  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1552  return failure();
1553 
1554  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1555  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1556  return failure();
1557 
1558  return success();
1559 }
1560 
1561 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1562  MLIRContext *context, ::std::optional<Location> location,
1563  PadOp::Adaptor adaptor,
1564  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1565  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1566  auto paddingRank =
1567  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1568  SmallVector<int64_t> outputShape;
1569 
1570  // If the input rank is unknown, we can infer the output rank using the
1571  // padding shape's rank divided by 2.
1572  if (!inputShape.hasRank()) {
1573  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1574  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1575  return success();
1576  }
1577 
1578  SmallVector<int64_t> paddingValues;
1579  // If the padding value is not a constant, all output dimensions are dynamic.
1580  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1581  paddingValues)) {
1582  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1583  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1584  return success();
1585  }
1586 
1587  outputShape.reserve(inputShape.getRank());
1588  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1589  if (inputShape.isDynamicDim(i)) {
1590  outputShape.push_back(ShapedType::kDynamic);
1591  continue;
1592  }
1593  auto padFront = paddingValues[i * 2];
1594  auto padBack = paddingValues[i * 2 + 1];
1595  if (padFront < 0 || padBack < 0) {
1596  // If either padding value for dim i is negative (dynamic), the output dim is unknown.
1597  outputShape.push_back(ShapedType::kDynamic);
1598  continue;
1599  }
1600 
1601  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1602  }
1603 
1604  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1605  return success();
1606 }
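// For example, an input of type tensor<2x3xf32> with constant padding values
// [1, 1, 2, 2] infers an output shape of (2+1+1)x(3+2+2) = 4x7; a negative
// (dynamic) padding entry leaves the corresponding output dimension unknown.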
1607 
1608 LogicalResult tosa::PadOp::verify() {
1609  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1610  /* outType = */ getOutput().getType())
1611  .failed()) {
1612  return failure();
1613  }
1614 
1615  if (auto padConst = getPadConst()) {
1616  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1617  /* outType = */ getOutput().getType())
1618  .failed()) {
1619  return failure();
1620  }
1621  }
1622 
1623  RankedTensorType inputType =
1624  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1625  RankedTensorType outputType =
1626  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1627  if (!inputType || !outputType)
1628  return success();
1629 
1630  auto inputRank = inputType.getRank();
1631  auto outputRank = outputType.getRank();
1632  if (inputRank != outputRank)
1633  return emitOpError() << "expect same input and output tensor rank, but got "
1634  << "inputRank: " << inputRank
1635  << ", outputRank: " << outputRank;
1636 
1637  DenseIntElementsAttr paddingAttr;
1638  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1639  return failure();
1640  }
1641 
1642  auto paddingValues = paddingAttr.getValues<APInt>();
1643  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1644  return emitOpError() << "padding tensor must have " << inputRank
1645  << " * 2 = " << inputRank * 2 << " elements, but got "
1646  << paddingValues.size();
1647 
1648  auto inputShape = inputType.getShape();
1649  auto outputShape = outputType.getShape();
1650 
1651  for (int64_t i = 0; i < inputRank; ++i) {
1652  int64_t padStart = paddingValues[i * 2].getSExtValue();
1653  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1654 
1655  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1656  return emitOpError()
1657  << "invalid padding values at dimension " << i
1658  << ": values must be non-negative or -1 for dynamic padding, got ["
1659  << padStart << ", " << padEnd << "]";
1660  }
1661 
1662  // Skip shape verification for dynamic input/output
1663  if (inputShape[i] == ShapedType::kDynamic ||
1664  outputShape[i] == ShapedType::kDynamic)
1665  continue;
1666 
1667  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1668  return emitOpError() << "mismatch in output shape at dimension " << i
1669  << ": expected " << inputShape[i] << " + "
1670  << padStart << " + " << padEnd << " = "
1671  << (inputShape[i] + padStart + padEnd)
1672  << ", but got " << outputShape[i];
1673  }
1674  }
1675 
1676  return success();
1677 }
1678 
1679 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
1680  return to_vector(llvm::map_range(shape, [](int64_t dim) {
1681  return dim == -1 ? ShapedType::kDynamic : dim;
1682  }));
1683 }
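// For example, convertToMlirShape({2, -1, 4}) yields {2, ShapedType::kDynamic, 4}.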
1684 
1685 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1686  MLIRContext *context, ::std::optional<Location> location,
1687  SliceOp::Adaptor adaptor,
1688  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1689 
1690  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1691  SmallVector<int64_t> start;
1692  SmallVector<int64_t> size;
1693 
1694  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1695  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1696  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1697  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1698  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1699  return success();
1700  }
1701 
1702  // if size[i] is -1, all remaining elements in dimension i are included
1703  // in the slice, similar to TF.
1704  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1705  // initialize outputShape to all unknown
1706  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1707  if (inputShape.hasRank()) {
1708  for (size_t i = 0; i < size.size(); i++) {
1709  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1710  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1711  start[i] < inputShape.getDimSize(i))) {
1712  // size[i] is not 0 and not < -1, and start[i] is in valid range
1713  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1714  // input shape has unknown dim[i] - only valid if size[i] > 0
1715  if (size[i] > 0) {
1716  outputShape[i] = size[i];
1717  }
1718  } else {
1719  // input shape has known dim[i]
1720  if (size[i] == -1) {
1721  outputShape[i] = inputShape.getDimSize(i) - start[i];
1722  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1723  // start[i] + size[i] is within bound of input shape's dim[i]
1724  outputShape[i] = size[i];
1725  }
1726  }
1727  }
1728  }
1729  } else {
1730  outputShape = convertToMlirShape(size);
1731  }
1732  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1733  return success();
1734 }
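// For example, slicing an input of type tensor<10x20xf32> with start [2, 0] and
// size [4, -1] infers an output shape of 4x20: dimension 0 takes the explicit
// size 4, and the -1 in dimension 1 expands to 20 - 0 = 20.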
1735 
1736 LogicalResult tosa::SliceOp::verify() {
1737  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1738  /* outType = */ getOutput().getType())
1739  .failed())
1740  return failure();
1741 
1742  const ShapeAdaptor inputShape(getInput1().getType());
1743  if (inputShape.hasRank()) {
1744  const auto inputRank = inputShape.getRank();
1745  const ShapeAdaptor outputShape(getOutput().getType());
1746  if (outputShape.hasRank() && inputRank != outputShape.getRank())
1747  return emitOpError(
1748  "expect input1 and output to have the same ranks, got ")
1749  << inputRank << " and " << outputShape.getRank();
1750 
1751  const auto startShapeRank =
1752  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1753  if (inputRank != startShapeRank)
1754  return emitOpError("length of start is not equal to rank of input shape");
1755 
1756  const auto sizeShapeRank =
1757  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1758  if (inputRank != sizeShapeRank)
1759  return emitOpError("length of size is not equal to rank of input shape");
1760  }
1761 
1762  return success();
1763 }
1764 
1765 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1766  MLIRContext *context, ::std::optional<Location> location,
1767  ValueShapeRange operands, DictionaryAttr attributes,
1768  OpaqueProperties properties, RegionRange regions,
1769  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1770  // The mul op's output shape depends only on input1 and input2, not on shift.
1771  ValueShapeRange twoInputs = operands.drop_back();
1772  llvm::SmallVector<int64_t> outShape;
1773  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1774  inferredReturnShapes.push_back(ShapedTypeComponents());
1775  } else {
1776  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1777  }
1778  return success();
1779 }
1780 
1781 LogicalResult tosa::MulOp::verify() {
1782  auto resElemType = getElementTypeOrSelf(getOutput());
1783 
1784  // Verify if the element type among operands and result match tosa
1785  // specification.
1786  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1787  IntegerType lhsIntType =
1788  cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1789  IntegerType rhsIntType =
1790  cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1791  if (lhsIntType != rhsIntType)
1792  return emitOpError("requires the same element type for all operands");
1793 
1794  // Though the spec requires the element type of the result to be i32, a
1795  // more relaxed constraint is allowed at the dialect level to ease
1796  // interoperation with other dialects.
1797  if (lhsIntType.getWidth() > resIntType.getWidth())
1798  return emitOpError("invalid data type size for operands or result");
1799 
1800  } else {
1801  // For other supported types, the spec requires the same element type for
1802  // all operands (excluding the `shift` operand) and results.
1803  for (int i = 0; i < 2; ++i) {
1804  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1805  return emitOpError(
1806  "requires the same element type for all operands and results");
1807  }
1808 
1809  // verify shift has value 0 for non-integer types
1810  ElementsAttr shift_elem;
1811  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1812  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1813  if (shift != 0) {
1814  return emitOpError() << "require shift to be 0 for float type";
1815  }
1816  }
1817  }
1818 
1819  // Verify that all main operands and results have the same rank, if known.
1820  // Extra operands such as mul's shift are excluded, which is the only
1821  // difference from the built-in `SameOperandsAndResultRank` trait.
1822 
1823  // delegate function that returns true if type is a shaped type with known
1824  // rank
1825  auto hasRank = [](const Type type) {
1826  if (auto shaped_type = dyn_cast<ShapedType>(type))
1827  return shaped_type.hasRank();
1828 
1829  return false;
1830  };
1831 
1832  auto rankedOperandTypes =
1833  llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1834 
1835  auto rankedResultTypes =
1836  llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1837 
1838  // If all operands and results are unranked, then no further verification.
1839  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1840  return success();
1841 
1842  // delegate function that returns rank of shaped type with known rank
1843  auto getRank = [](const Type type) {
1844  return cast<ShapedType>(type).getRank();
1845  };
1846 
1847  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1848  : getRank(*rankedResultTypes.begin());
1849 
1850  for (size_t i = 0; i < 2; ++i) {
1851  if (rank != getRank(rankedOperandTypes[i])) {
1852  return emitOpError("operands don't have matching ranks");
1853  }
1854  }
1855 
1856  for (const auto type : rankedResultTypes) {
1857  if (rank != getRank(type)) {
1858  return emitOpError("result type has different rank than operands");
1859  }
1860  }
1861 
1862  // check for broadcast compatible shapes in first two operands (ignoring
1863  // shift)
1864 
1865  // delegate function that returns shape of shaped type
1866  auto getShape = [](const Type type) {
1867  return mlir::cast<ShapedType>(type).getShape();
1868  };
1869  SmallVector<int64_t> resultShape;
1870  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1871  getShape(rankedOperandTypes[1]),
1872  resultShape)) {
1873  return emitOpError("operands don't have broadcast-compatible shapes");
1874  }
1875 
1876  return success();
1877 }
1878 
1879 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1880  MLIRContext *context, ::std::optional<Location> location,
1881  TableOp::Adaptor adaptor,
1882  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1883  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1884 
1885  if (!inputShape.hasRank()) {
1886  inferredReturnShapes.push_back(ShapedTypeComponents());
1887  return success();
1888  }
1889 
1890  inferredReturnShapes.resize(1);
1891  inputShape.getDims(inferredReturnShapes[0]);
1892  return success();
1893 }
1894 
1895 LogicalResult tosa::TableOp::verify() {
1896  TensorType inputType = getInput1().getType();
1897  TensorType outputType = getOutput().getType();
1898 
1899  if (inputType.hasRank() && outputType.hasRank() &&
1900  inputType.getRank() != outputType.getRank())
1901  return emitOpError()
1902  << "expected input tensor rank to equal result tensor rank";
1903 
1904  auto inputDims = inputType.getShape();
1905  auto outputDims = outputType.getShape();
1906  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1907  int64_t dim = it.index();
1908  auto [inputDim, outputDim] = it.value();
1909  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1910  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1911  << " doesn't match dim(input, " << dim
1912  << ") = " << inputDim;
1913  }
1914  }
1915  return success();
1916 }
1917 
1918 LogicalResult
1919 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1920  // Multiples must be constants.
1921  DenseIntElementsAttr multiplesAttr;
1922  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1923  return failure();
1924  multiples = llvm::to_vector(
1925  llvm::map_range(multiplesAttr.getValues<APInt>(),
1926  [](const APInt &val) { return val.getSExtValue(); }));
1927  return success();
1928 }
1929 
1930 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1931  MLIRContext *context, ::std::optional<Location> location,
1932  TileOp::Adaptor adaptor,
1933  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1934  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1935  SmallVector<int64_t> multiples;
1936  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1937  multiples)) {
1938  auto rank =
1939  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1940  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1941  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1942  return success();
1943  } else {
1944  multiples = convertToMlirShape(multiples);
1945  }
1946 
1947  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1948  SmallVector<int64_t> outputShape;
1949  if (!inputShape.hasRank()) {
1950  outputShape.resize(multiples.size(), ShapedType::kDynamic);
1951  inferredReturnShapes.push_back(
1952  ShapedTypeComponents(outputShape, inputType));
1953  return success();
1954  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1955  return failure();
1956 
1957  // Any non-dynamic dimension can be multiplied out to a known size.
1958  outputShape.reserve(multiples.size());
1959  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1960  if (multiples[i] == ShapedType::kDynamic) {
1961  outputShape.push_back(ShapedType::kDynamic);
1962  } else {
1963  int64_t dim = inputShape.getDimSize(i);
1964  if (dim != ShapedType::kDynamic)
1965  dim *= multiples[i];
1966  outputShape.push_back(dim);
1967  }
1968  }
1969 
1970  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1971  return success();
1972 }
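// For example, tiling an input of type tensor<2x?x4xf32> with multiples
// [3, 2, 1] infers an output shape of 6x?x4: static dims are multiplied,
// while dynamic input dims (and multiples of -1) stay dynamic.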
1973 
1974 LogicalResult tosa::TileOp::verify() {
1975  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1976  /* outType = */ getOutput().getType())
1977  .failed()) {
1978  return failure();
1979  }
1980  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1981  ShapedType outputType = llvm::cast<ShapedType>(getType());
1982 
1983  shapeType multiplesType =
1984  llvm::cast<tosa::shapeType>(getMultiples().getType());
1985 
1986  auto multiplesRank = multiplesType.getRank();
1987 
1988  if (inputType.hasRank()) {
1989  if (inputType.getRank() != multiplesRank)
1990  return emitOpError("expect 'multiples' to have rank ")
1991  << inputType.getRank() << " but got " << multiplesRank << ".";
1992  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1993  return emitOpError("expect same input and output tensor rank.");
1994  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1995  return emitOpError("expect 'multiples' array to have length ")
1996  << outputType.getRank() << " but got " << multiplesRank << ".";
1997 
1998  SmallVector<int64_t> multiples;
1999  if (getConstantMultiples(multiples).succeeded() &&
2000  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2001  return emitOpError(
2002  "expect element of 'multiples' to be positive integer or -1.");
2003 
2004  return success();
2005 }
2006 
2007 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2008  if (l.size() != r.size() || l.size() != 1)
2009  return false;
2010  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2011 }
2012 
2013 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2014  MLIRContext *context, ::std::optional<Location> location,
2015  ReshapeOp::Adaptor adaptor,
2016  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2017  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2018  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2019  llvm::SmallVector<int64_t> newShapeValue;
2020  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2021  newShapeValue)) {
2022  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2023  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2024  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2025  return success();
2026  } else {
2027  newShapeValue = convertToMlirShape(newShapeValue);
2028  }
2029 
2030  // We cannot infer from the total number of elements so we must take the
2031  // shape attribute as exact.
2032  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2033  inferredReturnShapes.push_back(
2034  ShapedTypeComponents(newShapeValue, inputType));
2035  return success();
2036  }
2037 
2038  // Determine the number of elements covered by the slice of all static
2039  // dimensions. This allows us to infer the length of the remaining dynamic
2040  // dimension.
2041  int64_t numElements = inputShape.getNumElements();
2042  int64_t staticMul = 1;
2043  for (auto val : newShapeValue) {
2044  if (!ShapedType::isDynamic(val)) {
2045  staticMul *= val;
2046  }
2047  }
2048 
2049  // Determine the length of the dynamic dimension.
2050  for (auto &val : newShapeValue) {
2051  if (ShapedType::isDynamic(val))
2052  val = numElements / staticMul;
2053  }
2054 
2055  inferredReturnShapes.push_back(
2056  ShapedTypeComponents(newShapeValue, inputType));
2057  return success();
2058 }
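// For example, reshaping an input of type tensor<2x3x4xf32> (24 elements) with
// shape [4, -1] infers the dynamic dimension as 24 / 4 = 6, giving an output
// shape of 4x6.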
2059 
2060 llvm::LogicalResult tosa::ReshapeOp::verify() {
2061  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2062  /* outType = */ getOutput().getType())
2063  .failed()) {
2064  return failure();
2065  }
2066  TensorType inputType = getInput1().getType();
2067  RankedTensorType outputType = getType();
2068 
2069  SmallVector<int64_t> shapeValues;
2070  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2071  // skip following checks if shape is not constant
2072  return mlir::success();
2073  }
2074 
2075  if ((int64_t)shapeValues.size() != outputType.getRank())
2076  return emitOpError() << "new shape does not match result rank";
2077 
2078  for (auto [newShapeDim, outputShapeDim] :
2079  zip(shapeValues, outputType.getShape())) {
2080  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2081  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2082  return emitOpError() << "new shape is inconsistent with result shape";
2083 
2084  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2085  return emitOpError() << "new shape has invalid tensor dimension size "
2086  << newShapeDim;
2087  }
2088 
2089  if (inputType.hasStaticShape()) {
2090  int64_t inputElementsNum = inputType.getNumElements();
2091  if (outputType.hasStaticShape()) {
2092  int64_t outputElementsNum = outputType.getNumElements();
2093  if (inputElementsNum != outputElementsNum) {
2094  return emitOpError() << "cannot reshape " << inputElementsNum
2095  << " elements into " << outputElementsNum;
2096  }
2097  }
2098 
2099  int64_t newShapeElementsNum = std::accumulate(
2100  shapeValues.begin(), shapeValues.end(), 1LL,
2101  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2102  bool isStaticNewShape =
2103  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2104  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2105  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2106  return emitOpError() << "cannot reshape " << inputElementsNum
2107  << " elements into " << newShapeElementsNum;
2108  }
2109  }
2110 
2111  int missingDims = llvm::count(shapeValues, -1);
2112  if (missingDims > 1)
2113  return emitOpError() << "expected at most one target dimension to be -1";
2114 
2115  return mlir::success();
2116 }
2117 
2118 // Return failure if val is not a constant.
2119 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
2120 // otherwise set zp to val's constant value.
2121 static FailureOr<int64_t> getZeroPoint(Value val) {
2122  ElementsAttr zpAttr;
2123  if (!matchPattern(val, m_Constant(&zpAttr))) {
2124  return failure();
2125  }
2126 
2127  Type zpElemType = zpAttr.getElementType();
2128 
2129  if (llvm::isa<FloatType>(zpElemType)) {
2130  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2131  return 0;
2132  }
2133  // return non-zero value to trigger error check
2134  return -1;
2135  }
2136 
2137  if (llvm::isa<IntegerType>(zpElemType)) {
2138  return zpAttr.getValues<APInt>()[0].getSExtValue();
2139  }
2140 
2141  // return non-zero value to trigger error check
2142  return -1;
2143 }
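// For example, a constant f32 zero point of 0.0 yields 0, a constant i8 zero
// point of -128 yields -128, a non-zero float constant yields -1 (flagged by
// the verifier), and a non-constant value returns failure.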
2144 
2145 template <typename T>
2146 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2147  const std::string &operand) {
2148  Type zpElemType = getElementTypeOrSelf(val);
2149 
2150  if (!zpElemType.isInteger(8) && zp != 0) {
2151  // convert operand to lower case for error message
2152  std::string lower = operand;
2153  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2154  return op.emitOpError()
2155  << lower << " zero point must be zero for non-int8 integer types";
2156  }
2157 
2158  return success();
2159 }
2160 
2161 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2162  const int64_t &zp,
2163  const std::string &operand) {
2164  bool isInputZp = (operand == "Input");
2165 
2166  bool tensorUnsigned =
2167  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2168  StringRef tensorName = isInputZp ? "input" : "output";
2169 
2170  Type zpElemType = getElementTypeOrSelf(zpVal);
2171 
2172  if (zp != 0) {
2173  if (!zpElemType.isInteger(8) &&
2174  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2175  return op.emitOpError()
2176  << "expect " << tensorName << "_zp of 0, got " << zp;
2177  }
2178  if (zpElemType.isInteger(16) && tensorUnsigned &&
2179  zp != static_cast<int16_t>(32768)) {
2180  return op.emitOpError() << "expect " << tensorName
2181  << "_zp of 0 or 32768 for unsigned int16 "
2182  << tensorName << ", got " << zp;
2183  }
2184  }
2185 
2186  return success();
2187 }
2188 
2189 #define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
2190  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2191  return getZeroPoint(get##OPERAND_NAME##Zp()); \
2192  } \
2193  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2194  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2195  }
2196 
2197 ZERO_POINT_HELPER(Conv2DOp, Input)
2198 ZERO_POINT_HELPER(Conv2DOp, Weight)
2199 ZERO_POINT_HELPER(Conv3DOp, Input)
2200 ZERO_POINT_HELPER(Conv3DOp, Weight)
2201 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
2202 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
2203 ZERO_POINT_HELPER(TransposeConv2DOp, Input)
2204 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
2205 ZERO_POINT_HELPER(AvgPool2dOp, Input)
2206 ZERO_POINT_HELPER(AvgPool2dOp, Output)
2207 ZERO_POINT_HELPER(MatMulOp, A)
2208 ZERO_POINT_HELPER(MatMulOp, B)
2209 ZERO_POINT_HELPER(NegateOp, Input1)
2210 ZERO_POINT_HELPER(NegateOp, Output)
2211 ZERO_POINT_HELPER(RescaleOp, Input)
2212 ZERO_POINT_HELPER(RescaleOp, Output)
2213 #undef ZERO_POINT_HELPER
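// For example, ZERO_POINT_HELPER(MatMulOp, A) above defines
// tosa::MatMulOp::getAZeroPoint(), which reads the constant a_zp operand, and
// tosa::MatMulOp::verifyAZeroPoint(int64_t), which checks that value.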
2214 
2215 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2216  MLIRContext *context, ::std::optional<Location> location,
2217  TransposeOp::Adaptor adaptor,
2218  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2219  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2220 
2221  // If the input rank is unknown, the output rank is also unknown.
2222  // (The perms attribute alone is not used to infer the rank here.)
2223  if (!inputShape.hasRank()) {
2224  inferredReturnShapes.push_back(ShapedTypeComponents());
2225  return success();
2226  }
2227 
2228  const auto inputRank = inputShape.getRank();
2229 
2230  // If the number of permutation entries does not match the rank of the
2231  // input, the op is illegal.
2232  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2233  return failure();
2234  }
2235 
2236  SmallVector<int64_t> outputShape;
2237  // Rank-0 means no permutations matter.
2238  if (inputRank == 0) {
2239  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2240  return success();
2241  }
2242 
2243  // Check whether the input dimensions are all the same.
2244  bool allTheSame = true;
2245  for (int i = 1, s = inputRank; i < s; i++) {
2246  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2247  allTheSame = false;
2248  break;
2249  }
2250  }
2251 
2252  // If all of the input dimensions are the same we don't care about the
2253  // permutation.
2254  if (allTheSame) {
2255  outputShape.resize(inputRank, inputShape.getDimSize(0));
2256  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2257  return success();
2258  }
2259 
2260  outputShape.resize(inputRank, ShapedType::kDynamic);
2261 
2262  // Constant permutation values must be within the input rank.
2263  if (llvm::any_of(adaptor.getPerms(),
2264  [inputRank](const auto i) { return i >= inputRank; }))
2265  return failure();
2266 
2267  outputShape.reserve(inputRank);
2268  for (int i = 0, s = inputRank; i < s; i++) {
2269  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2270  }
2271 
2272  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2273  return success();
2274 }
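// For example, transposing an input of type tensor<1x2x3xf32> with perms
// [2, 0, 1] infers an output shape of 3x1x2, since output dim i takes the size
// of input dim perms[i].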
2275 
2276 LogicalResult tosa::TransposeOp::verify() {
2277  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2278  /* outType = */ getOutput().getType())
2279  .failed()) {
2280  return failure();
2281  }
2282 
2283  const ShapeAdaptor inputShape(getInput1().getType());
2284  const ShapeAdaptor outputShape(getOutput().getType());
2285 
2286  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2287 
2288  if (inputShape.hasRank() &&
2289  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2290  return emitOpError() << "expected perms attribute to have size "
2291  << inputShape.getRank()
2292  << " (input rank) but got size "
2293  << constantPerms.size();
2294 
2295  if (inputShape.hasRank() && outputShape.hasRank() &&
2296  inputShape.getRank() != outputShape.getRank())
2297  return emitOpError()
2298  << "expected input tensor rank to equal result tensor rank";
2299 
2300  if (outputShape.hasRank() &&
2301  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2302  return emitOpError() << "expected perms attribute to have size "
2303  << outputShape.getRank()
2304  << " (output rank) but got size "
2305  << constantPerms.size();
2306 
2307  if (!llvm::all_of(constantPerms,
2308  [&constantPerms](int32_t s) {
2309  return s >= 0 &&
2310  static_cast<size_t>(s) < constantPerms.size();
2311  }) ||
2312  !isPermutationVector(llvm::to_vector(llvm::map_range(
2313  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2314  return emitOpError() << "expected valid permutation indices";
2315 
2316  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2317  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2318  inputShape.getNumElements() != outputShape.getNumElements())
2319  return emitOpError() << "expected input1 and output to have same numbers "
2320  "of elements, got "
2321  << inputShape.getNumElements() << " and "
2322  << outputShape.getNumElements();
2323 
2324  // Verify that the types of the input and output tensors are properly
2325  // permuted.
2326  if (inputShape.hasRank() && outputShape.hasRank()) {
2327  for (auto i = 0; i < outputShape.getRank(); i++) {
2328  if (inputShape.isDynamicDim(constantPerms[i]) ||
2329  outputShape.isDynamicDim(i))
2330  continue;
2331 
2332  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2333  return emitOpError()
2334  << "expected output tensor dim " << i << " to match "
2335  << "input dim " << constantPerms[i] << " with value of "
2336  << inputShape.getDimSize(constantPerms[i]);
2337  }
2338  }
2339 
2340  return success();
2341 }
2342 
2343 LogicalResult TransposeOp::reifyResultShapes(
2344  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2345 
2346  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2347 
2348  Value input = getInput1();
2349  auto inputType = cast<TensorType>(input.getType());
2350 
2351  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2352  for (auto dim : transposePerms) {
2353  int32_t dimInInput = transposePerms[dim];
2354  if (inputType.isDynamicDim(dimInInput))
2355  returnedDims[dim] =
2356  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2357  .getResult();
2358  else
2359  returnedDims[dim] =
2360  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2361  }
2362 
2363  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2364  return success();
2365 }
2366 
2367 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2368  MLIRContext *context, ::std::optional<Location> location,
2369  GatherOp::Adaptor adaptor,
2370  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2371  llvm::SmallVector<int64_t> outputShape;
2372  outputShape.resize(3, ShapedType::kDynamic);
2373 
2374  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2375  if (valuesShape.hasRank()) {
2376  outputShape[0] = valuesShape.getDimSize(0);
2377  outputShape[2] = valuesShape.getDimSize(2);
2378  }
2379 
2380  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2381  if (indicesShape.hasRank()) {
2382  if (outputShape[0] == ShapedType::kDynamic)
2383  outputShape[0] = indicesShape.getDimSize(0);
2384  if (outputShape[1] == ShapedType::kDynamic)
2385  outputShape[1] = indicesShape.getDimSize(1);
2386  }
2387 
2388  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2389  return success();
2390 }
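// For example, values of type tensor<2x10x4xf32> and indices of type
// tensor<2x3xi32> infer an output shape of 2x3x4 (N and C from values, W from
// indices).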
2391 
2392 LogicalResult tosa::GatherOp::verify() {
2393  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2394  /* outType = */ getOutput().getType())
2395  .failed()) {
2396  return failure();
2397  }
2398 
2399  const ShapeAdaptor valuesShape(getValues().getType());
2400  const ShapeAdaptor indicesShape(getIndices().getType());
2401  const ShapeAdaptor outputShape(getOutput().getType());
2402 
2403  int64_t N = ShapedType::kDynamic;
2404  int64_t W = ShapedType::kDynamic;
2405  int64_t C = ShapedType::kDynamic;
2406 
2407  if (valuesShape.hasRank()) {
2408  N = valuesShape.getDimSize(0);
2409  C = valuesShape.getDimSize(2);
2410  }
2411  if (indicesShape.hasRank()) {
2412  const int64_t indicesN = indicesShape.getDimSize(0);
2413  W = indicesShape.getDimSize(1);
2414  if (N == ShapedType::kDynamic)
2415  N = indicesN;
2416  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2417  return emitOpError() << "requires indices dimension 0 to have size " << N
2418  << ", got " << indicesN;
2419  }
2420  if (outputShape.hasRank()) {
2421  const int64_t outputN = outputShape.getDimSize(0);
2422  const int64_t outputW = outputShape.getDimSize(1);
2423  const int64_t outputC = outputShape.getDimSize(2);
2424  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2425  N != outputN)
2426  return emitOpError() << "requires output dimension 0 to have size " << N
2427  << ", got " << outputN;
2428 
2429  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2430  W != outputW)
2431  return emitOpError() << "requires output dimension 1 to have size " << W
2432  << ", got " << outputW;
2433  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2434  C != outputC)
2435  return emitOpError() << "requires output dimension 2 to have size " << C
2436  << ", got " << outputC;
2437  }
2438  return success();
2439 }
2440 
2441 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2442  MLIRContext *context, ::std::optional<Location> location,
2443  ResizeOp::Adaptor adaptor,
2444  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2445  llvm::SmallVector<int64_t, 4> outputShape;
2446  outputShape.resize(4, ShapedType::kDynamic);
2447 
2448  ShapeAdaptor inputShape(adaptor.getInput().getType());
2449  if (!inputShape.hasRank())
2450  return failure();
2451 
2452  outputShape[0] = inputShape.getDimSize(0);
2453  outputShape[3] = inputShape.getDimSize(3);
2454  int64_t inputHeight = inputShape.getDimSize(1);
2455  int64_t inputWidth = inputShape.getDimSize(2);
2456 
2457  if ((inputHeight == ShapedType::kDynamic) ||
2458  (inputWidth == ShapedType::kDynamic))
2459  return failure();
2460 
2461  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2462  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2463  scaleInt) ||
2464  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2465  offsetInt) ||
2466  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2467  borderInt)) {
2468  return failure();
2469  }
2470 
2471  // Compute the output shape based on attributes: scale, offset, and border.
2472  outputShape[1] =
2473  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2474  scaleInt[1]) +
2475  1;
2476 
2477  outputShape[2] =
2478  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2479  scaleInt[3]) +
2480  1;
2481 
2482  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2483  return success();
2484 }
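// For example, an input of type tensor<1x32x32x8xf32> with scale [2, 1, 2, 1],
// offset [0, 0] and border [0, 0] infers an output height and width of
// ((32 - 1) * 2 - 0 + 0) / 1 + 1 = 63, i.e. an output shape of 1x63x63x8.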
2485 
2486 LogicalResult tosa::ResizeOp::verify() {
2487  const Value input = getInput();
2488  const Value output = getOutput();
2489  const RankedTensorType inputType =
2490  llvm::dyn_cast<RankedTensorType>(input.getType());
2491  const RankedTensorType outputType =
2492  llvm::dyn_cast<RankedTensorType>(output.getType());
2493 
2494  if (!inputType)
2495  return emitOpError("expect a ranked input tensor");
2496  if (!outputType)
2497  return emitOpError("expect a ranked output tensor");
2498 
2499  const int64_t oh = outputType.getDimSize(1);
2500  const int64_t ow = outputType.getDimSize(2);
2501  const int64_t ih = inputType.getDimSize(1);
2502  const int64_t iw = inputType.getDimSize(2);
2503 
2504  SmallVector<int64_t> scaleValues;
2505  SmallVector<int64_t> offsetValues;
2506  SmallVector<int64_t> borderValues;
2507  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2508  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2509  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2510  // Skip following checks if shape is not constant
2511  return success();
2512  }
2513 
2514  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2515  return emitOpError("expect all scale values to be > 0, got ")
2516  << scaleValues;
2517 
2518  const int64_t scaleYN = scaleValues[0];
2519  const int64_t scaleYD = scaleValues[1];
2520  const int64_t scaleXN = scaleValues[2];
2521  const int64_t scaleXD = scaleValues[3];
2522 
2523  const int64_t offsetY = offsetValues[0];
2524  const int64_t offsetX = offsetValues[1];
2525 
2526  const int64_t borderY = borderValues[0];
2527  const int64_t borderX = borderValues[1];
2528 
2529  // Skip this check when the input height could be broadcast (ih == 1),
2530  // since Linalg, a consumer of TOSA, expects broadcasting support
2531  // in resize to be available. Taking the cautious approach for now,
2532  // we can consider removing support for broadcasting later.
2533  if (ih != ShapedType::kDynamic && ih != 1) {
2534  const std::optional<int64_t> calculatedOutHeightMinusOne =
2535  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2536  if (!calculatedOutHeightMinusOne.has_value())
2537  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2538  "border_y ")
2539  << "to be wholly divisible by scale_y_d, got ((" << ih
2540  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2541  << ") / " << scaleYD;
2542  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2543  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2544  return emitOpError("calculated output height did not match expected: ")
2545  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2546  }
2547 
2548  // Skip this check when the input width could be broadcast (iw == 1),
2549  // since Linalg, a consumer of TOSA, expects broadcasting support
2550  // in resize to be available. Taking the cautious approach for now,
2551  // we can consider removing support for broadcasting later.
2552  if (iw != ShapedType::kDynamic && iw != 1) {
2553  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2554  const std::optional<int64_t> calculatedOutWidthMinusOne =
2555  idivCheck(scaledInWidth, scaleXD);
2556  if (!calculatedOutWidthMinusOne.has_value())
2557  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2558  "border_x ")
2559  << "to be wholly divisible by scale_x_d, got ((" << iw
2560  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2561  << ") / " << scaleXD;
2562  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2563  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2564  return emitOpError("calculated output width did not match expected: ")
2565  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2566  }
2567 
2568  return success();
2569 }
2570 
2571 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2572  MLIRContext *context, ::std::optional<Location> location,
2573  ScatterOp::Adaptor adaptor,
2574  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2575  llvm::SmallVector<int64_t> outputShape;
2576  outputShape.resize(3, ShapedType::kDynamic);
2577 
2578  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2579  if (valuesInShape.hasRank()) {
2580  outputShape[0] = valuesInShape.getDimSize(0);
2581  outputShape[1] = valuesInShape.getDimSize(1);
2582  outputShape[2] = valuesInShape.getDimSize(2);
2583  }
2584 
2585  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2586  if (indicesShape.hasRank()) {
2587  if (outputShape[0] == ShapedType::kDynamic)
2588  outputShape[0] = indicesShape.getDimSize(0);
2589  }
2590 
2591  ShapeAdaptor inputShape(adaptor.getInput().getType());
2592  if (inputShape.hasRank()) {
2593  if (outputShape[0] == ShapedType::kDynamic)
2594  outputShape[0] = inputShape.getDimSize(0);
2595  if (outputShape[2] == ShapedType::kDynamic)
2596  outputShape[2] = inputShape.getDimSize(2);
2597  }
2598 
2599  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2600  return success();
2601 }
2602 
2603 LogicalResult tosa::ScatterOp::verify() {
2604  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2605  /* outType = */ getValuesOut().getType())
2606  .failed() ||
2607  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2608  /* outType = */ getValuesOut().getType())
2609  .failed()) {
2610  return failure();
2611  }
2612  return success();
2613 }
2614 
2615 static LogicalResult ReduceInferReturnTypes(
2616  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2617  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2618  int64_t axisVal = axis.getValue().getSExtValue();
2619  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2620  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2621  return success();
2622  }
2623 
2624  SmallVector<int64_t> outputShape;
2625  operandShape.getDims(outputShape);
2626  outputShape[axisVal] = 1;
2627  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2628  return success();
2629 }
2630 
2631 #define COMPATIBLE_RETURN_TYPES(OP) \
2632  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2633  if (l.size() != r.size() || l.size() != 1) \
2634  return false; \
2635  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2636  return false; \
2637  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2638  }
2639 
2640 #define REDUCE_SHAPE_INFER(OP) \
2641  LogicalResult OP::inferReturnTypeComponents( \
2642  MLIRContext *context, ::std::optional<Location> location, \
2643  OP::Adaptor adaptor, \
2644  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2645  Type inputType = \
2646  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2647  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2648  const Properties &prop = adaptor.getProperties(); \
2649  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2650  inferredReturnShapes); \
2651  } \
2652  COMPATIBLE_RETURN_TYPES(OP)
2653 
2654 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2655 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2656 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2657 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2658 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2659 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2660 #undef REDUCE_SHAPE_INFER
2661 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2662 #undef COMPATIBLE_RETURN_TYPES
2663 
2664 template <typename T>
2665 static LogicalResult verifyReduceOp(T op) {
2666  // All TOSA reduce Ops have input, output and axis.
2667  TensorType inputType = op.getInput().getType();
2668  TensorType outputType = op.getOutput().getType();
2669  int32_t reduceAxis = op.getAxis();
2670 
2671  if (reduceAxis < 0) {
2672  op.emitOpError("reduce axis must not be negative");
2673  return failure();
2674  }
2675  if (inputType.hasRank()) {
2676  int64_t inputRank = inputType.getRank();
2677  // We allow for a special case where the input/output shape has rank 0 and
2678  // axis is also 0.
2679  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2680  op.emitOpError("expect input tensor rank (")
2681  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2682  << ")";
2683  return failure();
2684  }
2685  }
2686  if (outputType.hasRank()) {
2687  int64_t outputRank = outputType.getRank();
2688  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2689  op.emitOpError(
2690  "expect output tensor rank to be equal to input tensor rank");
2691  return failure();
2692  }
2693  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2694  op.emitOpError("expect output tensor rank (")
2695  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2696  << ")";
2697  return failure();
2698  }
2699  // We can only verify the reduced dimension size to be 1 if this is not
2700  // the special case of output rank == 0.
2701  if (outputRank != 0) {
2702  auto outputShape = outputType.getShape();
2703  if (!outputType.isDynamicDim(reduceAxis) &&
2704  outputShape[reduceAxis] != 1) {
2705  op.emitOpError("expect reduced dimension size to be 1, got ")
2706  << outputShape[reduceAxis];
2707  return failure();
2708  }
2709  }
2710  }
2711  return success();
2712 }
2713 
2714 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2715 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2716 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2717 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2718 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2719 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2720 
2721 static LogicalResult NAryInferReturnTypes(
2722  const ValueShapeRange &operands,
2723  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2724  llvm::SmallVector<int64_t> outShape;
2725  if (resolveBroadcastShape(operands, outShape).failed()) {
2726  inferredReturnShapes.push_back(ShapedTypeComponents());
2727  } else {
2728  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2729  }
2730  return success();
2731 }
2732 
2733 #define NARY_SHAPE_INFER(OP) \
2734  LogicalResult OP::inferReturnTypeComponents( \
2735  MLIRContext *context, ::std::optional<Location> location, \
2736  ValueShapeRange operands, DictionaryAttr attributes, \
2737  OpaqueProperties properties, RegionRange regions, \
2738  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2739  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2740  }
2741 
2742 NARY_SHAPE_INFER(tosa::AbsOp)
2743 NARY_SHAPE_INFER(tosa::AddOp)
2744 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2745 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2746 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2747 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2748 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2749 NARY_SHAPE_INFER(tosa::CastOp)
2750 NARY_SHAPE_INFER(tosa::CeilOp)
2751 NARY_SHAPE_INFER(tosa::ClampOp)
2752 NARY_SHAPE_INFER(tosa::ClzOp)
2753 NARY_SHAPE_INFER(tosa::CosOp)
2754 NARY_SHAPE_INFER(tosa::ExpOp)
2755 NARY_SHAPE_INFER(tosa::FloorOp)
2756 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2757 NARY_SHAPE_INFER(tosa::GreaterOp)
2758 NARY_SHAPE_INFER(tosa::IdentityOp)
2759 NARY_SHAPE_INFER(tosa::IntDivOp)
2760 NARY_SHAPE_INFER(tosa::LogOp)
2761 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2762 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2763 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2764 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2765 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2766 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2767 NARY_SHAPE_INFER(tosa::MaximumOp)
2768 NARY_SHAPE_INFER(tosa::MinimumOp)
2769 NARY_SHAPE_INFER(tosa::PowOp)
2770 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2771 NARY_SHAPE_INFER(tosa::ReverseOp)
2772 NARY_SHAPE_INFER(tosa::RsqrtOp)
2773 NARY_SHAPE_INFER(tosa::SinOp)
2774 NARY_SHAPE_INFER(tosa::SelectOp)
2775 NARY_SHAPE_INFER(tosa::SubOp)
2776 NARY_SHAPE_INFER(tosa::TanhOp)
2777 NARY_SHAPE_INFER(tosa::ErfOp)
2778 NARY_SHAPE_INFER(tosa::SigmoidOp)
2779 #undef NARY_SHAPE_INFER
2780 
2781 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2782  MLIRContext *context, ::std::optional<Location> location,
2783  NegateOp::Adaptor adaptor,
2784  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2785  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2786  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2787  return success();
2788 }
2789 
2790 LogicalResult tosa::NegateOp::verify() {
2791  // Verify same element type
2792  const Type input1Type = getInput1().getType();
2793  const Type outputType = getOutput().getType();
2794  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2795  return failure();
2796 
2797  // Verify same shape
2798  const SmallVector<Type, 2> types = {input1Type, outputType};
2799  if (failed(verifyCompatibleShapes(types)))
2800  return emitOpError() << "requires the same shape for input1 and output";
2801 
2802  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2803  const Type input1ZpEType =
2804  getStorageElementTypeOrSelf(getInput1Zp().getType());
2805  if (input1EType != input1ZpEType) {
2806  return emitOpError("expect both input1 and its zero point are the same "
2807  "element type, got ")
2808  << input1EType << " and " << input1ZpEType;
2809  }
2810  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2811  const Type outputZpEType =
2812  getStorageElementTypeOrSelf(getOutputZp().getType());
2813  if (outputEType != outputZpEType) {
2814  return emitOpError("expect both output and its zero point are the same "
2815  "element type, got ")
2816  << outputEType << " and " << outputZpEType;
2817  }
2818 
2819  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2820  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2821  return failure();
2822 
2823  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2824  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2825  return failure();
2826 
2827  return success();
2828 }
2829 
2830 static LogicalResult poolingInferReturnTypes(
2831  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2832  ArrayRef<int64_t> pad,
2833  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2834  llvm::SmallVector<int64_t> outputShape;
2835  outputShape.resize(4, ShapedType::kDynamic);
2836 
2837  // With no input shape information we only know the output rank (all dims dynamic).
2838  if (!inputShape) {
2839  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2840  return success();
2841  }
2842 
2843  // Batch and channel dimensions are passed through unchanged by pooling.
2844  outputShape[0] = inputShape.getDimSize(0);
2845  outputShape[3] = inputShape.getDimSize(3);
2846 
2847  int64_t height = inputShape.getDimSize(1);
2848  int64_t width = inputShape.getDimSize(2);
2849 
2850  if (!ShapedType::isDynamic(height)) {
2851  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2852  outputShape[1] = padded / stride[0] + 1;
2853  }
2854 
2855  if (!ShapedType::isDynamic(width)) {
2856  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2857  outputShape[2] = padded / stride[1] + 1;
2858  }
2859 
2860  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2861  return success();
2862 }
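// For example, an input of type tensor<1x32x32x8xf32> with kernel [2, 2],
// stride [2, 2] and pad [0, 0, 0, 0] infers an output height and width of
// (32 + 0 + 0 - 2) / 2 + 1 = 16, i.e. an output shape of 1x16x16x8.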
2863 
2864 LogicalResult Conv2DOp::inferReturnTypeComponents(
2865  MLIRContext *context, ::std::optional<Location> location,
2866  Conv2DOp::Adaptor adaptor,
2867  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2868  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2869 
2870  int64_t inputWidth = ShapedType::kDynamic;
2871  int64_t inputHeight = ShapedType::kDynamic;
2872  int64_t weightWidth = ShapedType::kDynamic;
2873  int64_t weightHeight = ShapedType::kDynamic;
2874 
2875  // Input shape describes input width/height and batch.
2876 
2877  ShapeAdaptor inputShape(adaptor.getInput().getType());
2878  if (inputShape.hasRank()) {
2879  outputShape[0] = inputShape.getDimSize(0);
2880  inputHeight = inputShape.getDimSize(1);
2881  inputWidth = inputShape.getDimSize(2);
2882  }
2883 
2884  // The weight shape describes the filter height/width and the output channels.
2885  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2886  if (weightShape.hasRank()) {
2887  outputShape[3] = weightShape.getDimSize(0);
2888  weightHeight = weightShape.getDimSize(1);
2889  weightWidth = weightShape.getDimSize(2);
2890  }
2891 
2892  // Bias shape can describe the output channels.
2893  ShapeAdaptor biasShape(adaptor.getBias().getType());
2894  if (biasShape.hasRank()) {
2895  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2896  ? biasShape.getDimSize(0)
2897  : outputShape[3];
2898  }
2899 
2900  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2901  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2902  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2903 
2904  if (!ShapedType::isDynamic(inputHeight) &&
2905  !ShapedType::isDynamic(weightHeight)) {
2906  int64_t inputSize = inputHeight + padding[0] + padding[1];
2907  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2908  int64_t unstridedResult = inputSize - filterSize + 1;
2909  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2910  }
2911 
2912  if (!ShapedType::isDynamic(inputWidth) &&
2913  !ShapedType::isDynamic(weightWidth)) {
2914  int64_t inputSize = inputWidth + padding[2] + padding[3];
2915  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2916  int64_t unstridedResult = inputSize - filterSize + 1;
2917  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2918  }
2919 
2920  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2921  return success();
2922 }
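// For example, an input of type tensor<1x32x32x8xf32>, a weight of type
// tensor<16x3x3x8xf32>, stride [1, 1], pad [0, 0, 0, 0] and dilation [1, 1]
// give a filter extent of (3 - 1) * 1 + 1 = 3 and an unstrided extent of
// 32 - 3 + 1 = 30, so the inferred output shape is 1x30x30x16.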
2923 
2924 LogicalResult Conv2DOp::verify() {
2925  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
2926  verifyConvOpErrorIf(*this).failed())
2927  return failure();
2928  return success();
2929 }
2930 
2931 LogicalResult Conv3DOp::inferReturnTypeComponents(
2932  MLIRContext *context, ::std::optional<Location> location,
2933  Conv3DOp::Adaptor adaptor,
2934  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2935  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
2936 
2937  int64_t inputWidth = ShapedType::kDynamic;
2938  int64_t inputHeight = ShapedType::kDynamic;
2939  int64_t inputDepth = ShapedType::kDynamic;
2940 
2941  int64_t weightWidth = ShapedType::kDynamic;
2942  int64_t weightHeight = ShapedType::kDynamic;
2943  int64_t weightDepth = ShapedType::kDynamic;
2944 
2945  // Input shape describes input depth/height/width and batch.
2946  ShapeAdaptor inputShape(adaptor.getInput().getType());
2947  if (inputShape.hasRank()) {
2948  outputShape[0] = inputShape.getDimSize(0);
2949  inputDepth = inputShape.getDimSize(1);
2950  inputHeight = inputShape.getDimSize(2);
2951  inputWidth = inputShape.getDimSize(3);
2952  }
2953 
2954  // Weight shape describes the filter depth/height/width and output channels.
2955  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2956  if (weightShape.hasRank()) {
2957  outputShape[4] = weightShape.getDimSize(0);
2958  weightDepth = weightShape.getDimSize(1);
2959  weightHeight = weightShape.getDimSize(2);
2960  weightWidth = weightShape.getDimSize(3);
2961  }
2962 
2963  // Bias shape can describe the output channels.
2964  ShapeAdaptor biasShape(adaptor.getBias().getType());
2965  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
2966  outputShape[4] = biasShape.getDimSize(0);
2967  }
2968 
2969  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2970  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2971  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
2972 
2973  if (!ShapedType::isDynamic(inputDepth) &&
2974  !ShapedType::isDynamic(weightDepth)) {
2975  int64_t inputSize = inputDepth + pad[0] + pad[1];
2976  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
2977  int64_t unstridedResult = inputSize - filterSize + 1;
2978  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2979  }
2980 
2981  if (!ShapedType::isDynamic(inputHeight) &&
2982  !ShapedType::isDynamic(weightHeight)) {
2983  int64_t inputSize = inputHeight + pad[2] + pad[3];
2984  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
2985  int64_t unstridedResult = inputSize - filterSize + 1;
2986  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2987  }
2988 
2989  if (!ShapedType::isDynamic(inputWidth) &&
2990  !ShapedType::isDynamic(weightWidth)) {
2991  int64_t inputSize = inputWidth + pad[4] + pad[5];
2992  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
2993  int64_t unstridedResult = inputSize - filterSize + 1;
2994  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
2995  }
2996 
2997  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2998  return success();
2999 }
3000 
3001 LogicalResult Conv3DOp::verify() {
3002  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3003  verifyConvOpErrorIf(*this).failed())
3004  return failure();
3005  return success();
3006 }
3007 
3008 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3009  MLIRContext *context, ::std::optional<Location> location,
3010  AvgPool2dOp::Adaptor adaptor,
3011  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3012  ShapeAdaptor inputShape(adaptor.getInput().getType());
3013  const Properties &prop = adaptor.getProperties();
3014  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3015  inferredReturnShapes);
3016 }
3017 
3018 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3019  MLIRContext *context, ::std::optional<Location> location,
3020  MaxPool2dOp::Adaptor adaptor,
3021  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3022  ShapeAdaptor inputShape(adaptor.getInput().getType());
3023  const Properties &prop = adaptor.getProperties();
3024  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3025  inferredReturnShapes);
3026 }
3027 
3028 LogicalResult MaxPool2dOp::verify() {
3029  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3030  /* outType = */ getOutput().getType())))
3031  return failure();
3032 
3033  if (failed(verifyPoolingOp(*this)))
3034  return failure();
3035 
3036  return success();
3037 }
3038 
3039 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3040  MLIRContext *context, ::std::optional<Location> location,
3041  DepthwiseConv2DOp::Adaptor adaptor,
3042  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3043  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3044 
3045  int64_t inputWidth = ShapedType::kDynamic;
3046  int64_t inputHeight = ShapedType::kDynamic;
3047  int64_t inputChannels = ShapedType::kDynamic;
3048 
3049  int64_t weightWidth = ShapedType::kDynamic;
3050  int64_t weightHeight = ShapedType::kDynamic;
3051  int64_t depthChannels = ShapedType::kDynamic;
3052 
3053  // Input shape describes input height/width, channels and batch.
3054  ShapeAdaptor inputShape(adaptor.getInput().getType());
3055  if (inputShape.hasRank()) {
3056  outputShape[0] = inputShape.getDimSize(0);
3057  inputHeight = inputShape.getDimSize(1);
3058  inputWidth = inputShape.getDimSize(2);
3059  inputChannels = inputShape.getDimSize(3);
3060  }
3061 
3062  // Weight shape describes the filter height/width, input channels and multiplier.
3063  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3064  if (weightShape.hasRank()) {
3065  weightHeight = weightShape.getDimSize(0);
3066  weightWidth = weightShape.getDimSize(1);
3067  inputChannels = ShapedType::isDynamic(inputChannels)
3068  ? weightShape.getDimSize(2)
3069  : inputChannels;
3070  depthChannels = weightShape.getDimSize(3);
3071  }
3072 
3073  // If both inputChannels and depthChannels are available we can determine
3074  // the output channels.
3075  if (!ShapedType::isDynamic(inputChannels) &&
3076  !ShapedType::isDynamic(depthChannels)) {
3077  outputShape[3] = inputChannels * depthChannels;
3078  }
3079 
3080  // Bias shape can describe the output channels.
3081  ShapeAdaptor biasShape(adaptor.getBias().getType());
3082  if (biasShape.hasRank()) {
3083  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3084  ? biasShape.getDimSize(0)
3085  : outputShape[3];
3086  }
3087 
3088  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3089  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3090  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3091 
3092  if (!ShapedType::isDynamic(inputHeight) &&
3093  !ShapedType::isDynamic(weightHeight)) {
3094  int64_t inputSize = inputHeight + padding[0] + padding[1];
3095  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3096  int64_t unstridedResult = inputSize - filterSize + 1;
3097  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3098  }
3099 
3100  if (!ShapedType::isDynamic(inputWidth) &&
3101  !ShapedType::isDynamic(weightWidth)) {
3102  int64_t inputSize = inputWidth + padding[2] + padding[3];
3103  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3104  int64_t unstridedResult = inputSize - filterSize + 1;
3105  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3106  }
3107 
3108  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3109  return success();
3110 }
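// Worked example (illustrative numbers only) of the depthwise output-channel
// rule above: with inputChannels = 8 (read from the input's last dimension or
// from the weight's third dimension) and a channel multiplier depthChannels
// of 2, the inferred output channel count is 8 * 2 = 16. The spatial
// dimensions follow the same formula as Conv2DOp.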
3111 
3112 LogicalResult DepthwiseConv2DOp::verify() {
3113  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3114  verifyConvOpErrorIf(*this).failed())
3115  return failure();
3116  return success();
3117 }
3118 
3119 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3120  MLIRContext *context, ::std::optional<Location> location,
3121  TransposeConv2DOp::Adaptor adaptor,
3122  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3123  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3124 
3125  int64_t inputWidth = ShapedType::kDynamic;
3126  int64_t inputHeight = ShapedType::kDynamic;
3127  int64_t weightWidth = ShapedType::kDynamic;
3128  int64_t weightHeight = ShapedType::kDynamic;
3129 
3130  // Input shape describes input width/height and batch.
3131  ShapeAdaptor inputShape(adaptor.getInput().getType());
3132  if (inputShape.hasRank()) {
3133  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3134  ? inputShape.getDimSize(0)
3135  : outputShape[0];
3136  inputHeight = inputShape.getDimSize(1);
3137  inputWidth = inputShape.getDimSize(2);
3138  }
3139 
3140  // Weight shape describes the filter height/width and the output channels.
3141  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3142  if (weightShape.hasRank()) {
3143  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3144  ? weightShape.getDimSize(0)
3145  : outputShape[3];
3146  weightHeight = weightShape.getDimSize(1);
3147  weightWidth = weightShape.getDimSize(2);
3148  }
3149 
3150  // Bias shape can describe the output channels.
3151  ShapeAdaptor biasShape(adaptor.getBias().getType());
3152  if (biasShape.hasRank()) {
3153  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3154  ? biasShape.getDimSize(0)
3155  : outputShape[3];
3156  }
3157 
3158  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3159  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3160 
3161  if (!ShapedType::isDynamic(inputHeight) &&
3162  !ShapedType::isDynamic(weightHeight)) {
3163  int64_t calculateSize =
3164  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3165  outputShape[1] =
3166  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3167  }
3168 
3169  if (!ShapedType::isDynamic(inputWidth) &&
3170  !ShapedType::isDynamic(weightWidth)) {
3171  int64_t calculateSize =
3172  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3173  outputShape[2] =
3174  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3175  }
3176 
3177  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3178  return success();
3179 }
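// Worked example (illustrative numbers only) of the transpose-convolution
// output-size computation above, assuming inputHeight = 7, stride[0] = 2,
// out_pad = {0, 0} on the vertical axis and weightHeight = 3:
//   outputShape[1] = (7 - 1) * 2 + 0 + 0 + 3 = 15
// which is exactly the input height that a forward conv2d with kernel 3 and
// stride 2 would have reduced to 7.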
3180 
3181 LogicalResult TransposeConv2DOp::verify() {
3182  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3183  return failure();
3184 
3185  const llvm::ArrayRef<int64_t> strides = getStride();
3186  const int64_t strideY = strides[0];
3187  const int64_t strideX = strides[1];
3188 
3189  if (strideY < 1 || strideX < 1)
3190  return emitOpError("expect all stride values to be >= 1, got [")
3191  << strides << "]";
3192 
3193  const auto checkPadAgainstKernelDim =
3194  [this](int64_t pad_value, int64_t kernel_dim_size,
3195  llvm::StringRef pad_name,
3196  llvm::StringRef kernel_dim_name) -> LogicalResult {
3197  if (pad_value <= -kernel_dim_size)
3198  return emitOpError("expected ")
3199  << pad_name << " > -" << kernel_dim_name
3200  << ", but got: " << pad_name << "=" << pad_value << " and "
3201  << kernel_dim_name << "=" << kernel_dim_size;
3202  return success();
3203  };
3204 
3205  const llvm::ArrayRef<int64_t> padding = getOutPad();
3206  const int64_t outPadTop = padding[0];
3207  const int64_t outPadBottom = padding[1];
3208  const int64_t outPadLeft = padding[2];
3209  const int64_t outPadRight = padding[3];
3210 
3211  const auto weightType =
3212  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3213 
3214  if (weightType) {
3215  const int64_t kernelHeight = weightType.getDimSize(1);
3216  if (!ShapedType::isDynamic(kernelHeight)) {
3217  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3218  "out_pad_top", "KH")))
3219  return failure();
3220 
3221  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3222  "out_pad_bottom", "KH")))
3223  return failure();
3224  }
3225 
3226  const int64_t kernelWidth = weightType.getDimSize(2);
3227  if (!ShapedType::isDynamic(kernelWidth)) {
3228  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3229  "out_pad_left", "KW")))
3230  return failure();
3231 
3232  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3233  "out_pad_right", "KW")))
3234  return failure();
3235  }
3236  }
3237 
3238  // Rest of the checks depend on the output type being a RankedTensorType
3239  const auto outputType =
3240  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3241  if (!outputType)
3242  return success();
3243 
3244  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3245  if (inputType && weightType) {
3246  const int64_t inputHeight = inputType.getDimSize(1);
3247  const int64_t kernelHeight = weightType.getDimSize(1);
3248  const int64_t outputHeight = outputType.getDimSize(1);
3249 
3250  if (!ShapedType::isDynamic(inputHeight) &&
3251  !ShapedType::isDynamic(outputHeight)) {
3252  if (outputHeight !=
3253  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3254  return emitOpError(
3255  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3256  "+ out_pad_top + out_pad_bottom + KH, but got ")
3257  << outputHeight << " != (" << inputHeight << " - 1) * "
3258  << strideY << " + " << outPadTop << " + " << outPadBottom
3259  << " + " << kernelHeight;
3260  }
3261 
3262  const int64_t inputWidth = inputType.getDimSize(2);
3263  const int64_t kernelWidth = weightType.getDimSize(2);
3264  const int64_t outputWidth = outputType.getDimSize(2);
3265 
3266  if (!ShapedType::isDynamic(inputWidth) &&
3267  !ShapedType::isDynamic(outputWidth)) {
3268  if (outputWidth !=
3269  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3270  return emitOpError(
3271  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3272  "+ out_pad_left + out_pad_right + KW, but got ")
3273  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3274  << " + " << outPadLeft << " + " << outPadRight << " + "
3275  << kernelWidth;
3276  }
3277  }
3278 
3279  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3280 
3281  if (!biasType)
3282  return success();
3283 
3284  const int64_t biasChannels = biasType.getDimSize(0);
3285 
3286  // Skip further checks if bias is dynamic
3287  if (biasChannels == ShapedType::kDynamic)
3288  return success();
3289 
3290  const int64_t outputChannels = outputType.getDimSize(3);
3291  if (biasChannels != outputChannels && biasChannels != 1)
3292  return emitOpError(
3293  "bias channels expected to be equal to output channels (")
3294  << outputChannels << ") or 1, got " << biasChannels;
3295 
3296  return success();
3297 }
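// Worked example (illustrative numbers only) of the pad/kernel constraint
// enforced by checkPadAgainstKernelDim above: with KH = 3, out_pad_top must
// be strictly greater than -3, so -2 is accepted while -3 (or anything
// smaller) is rejected, since it would cancel the whole kernel extent.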
3298 
3299 LogicalResult RescaleOp::verify() {
3300  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3301  if (!inputType) {
3302  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3303  return failure();
3304  }
3305 
3306  auto inputElementType =
3307  getStorageElementTypeOrSelf(inputType.getElementType());
3308  if (!mlir::isa<IntegerType>(inputElementType)) {
3309  emitOpError("expect input to have integer element type, got ")
3310  << inputElementType;
3311  return failure();
3312  }
3313 
3314  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3315  if (!outputType) {
3316  emitOpError("expect shaped tensor for output, got ")
3317  << getOutput().getType();
3318  return failure();
3319  }
3320 
3321  auto outputElementType =
3322  getStorageElementTypeOrSelf(outputType.getElementType());
3323  if (!mlir::isa<IntegerType>(outputElementType)) {
3324  emitOpError("expect output to have integer element type, got ")
3325  << outputElementType;
3326  return failure();
3327  }
3328 
3329  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3330  .failed())
3331  return failure();
3332 
3333  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3334  .failed())
3335  return failure();
3336 
3337  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3338  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3339  return failure();
3340 
3341  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3342  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3343  return failure();
3344 
3345  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3346  if (!multiplierType) {
3347  emitOpError("expect shaped tensor for multiplier, got ")
3348  << getMultiplier().getType();
3349  return failure();
3350  }
3351 
3352  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3353  if (!shiftType) {
3354  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3355  return failure();
3356  }
3357 
3358  // multiplier element type must be i32 for scale32 = true
3359  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3360  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3361  << multiplierType.getElementType();
3362  return failure();
3363  }
3364 
3365  // multiplier element type must be i16 for scale32 = false
3366  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3367  emitOpError(
3368  "expect i16 element type for multiplier for scale32=false, got ")
3369  << multiplierType.getElementType();
3370  return failure();
3371  }
3372 
3373  if (!inputType.hasRank())
3374  return success();
3375 
3376  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3377  // if per_channel = false, otherwise numChannels is the size of the last
3378  // axis of the input shape.
3379  int64_t numChannels = 1;
3380  if (getPerChannel()) {
3381  if (inputType.getRank() < 1) {
3382  emitOpError("requires input to be at least rank 1 when per_channel is "
3383  "true, but got rank ")
3384  << inputType.getRank();
3385  return failure();
3386  }
3387  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3388  }
3389 
3390  if (!multiplierType.hasRank())
3391  return success();
3392 
3393  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3394  // multiplier input has rank 1 by dialect definition
3395  if (multiplierShape[0] != ShapedType::kDynamic &&
3396  multiplierShape[0] != numChannels) {
3397  emitOpError("expect shape of { ")
3398  << numChannels << " } for multiplier input, got { "
3399  << multiplierShape[0] << " }";
3400  return failure();
3401  }
3402 
3403  if (!shiftType.hasRank())
3404  return success();
3405 
3406  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3407  // shift input has rank 1 by dialect definition
3408  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3409  emitOpError("expect shape of { ")
3410  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3411  return failure();
3412  }
3413 
3414  return success();
3415 }
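// Worked example (illustrative shapes only) of the multiplier/shift shape
// rule checked above: rescaling a tensor<2x3x16xi8> input with
// per_channel = true makes numChannels 16, so multiplier and shift must both
// be rank-1 tensors of shape { 16 }; with per_channel = false they must have
// shape { 1 }.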
3416 
3417 LogicalResult RescaleOp::inferReturnTypeComponents(
3418  MLIRContext *context, ::std::optional<Location> location,
3419  RescaleOp::Adaptor adaptor,
3420  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3421  ShapeAdaptor inputShape(adaptor.getInput().getType());
3422  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3423  return success();
3424 }
3425 
3426 LogicalResult IfOp::inferReturnTypeComponents(
3427  MLIRContext *context, ::std::optional<Location> location,
3428  IfOp::Adaptor adaptor,
3429  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3430  llvm::SmallVector<tosa::YieldOp> yieldOps;
3431  for (Region *region : adaptor.getRegions()) {
3432  for (auto &block : *region)
3433  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3434  yieldOps.push_back(returnOp);
3435  }
3436 
3437  if (yieldOps.empty())
3438  return failure();
3439 
3440  // Get the initial type information for the yield op.
3441  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3442  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3443  for (auto operand : yieldOps.front().getOperands()) {
3444  resultKnowledge.push_back(
3445  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3446  }
3447 
3448  for (auto yieldOp : yieldOps) {
3449  if (resultKnowledge.size() != yieldOp.getNumOperands())
3450  return failure();
3451 
3452  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3453  int32_t index = it.index();
3454  auto meet = ValueKnowledge::meet(
3455  resultKnowledge[index],
3456  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3457  if (!meet)
3458  continue;
3459  resultKnowledge[index] = meet;
3460  }
3461  }
3462 
3463  for (const ValueKnowledge &result : resultKnowledge) {
3464  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3465  }
3466 
3467  return success();
3468 }
3469 
3470 LogicalResult WhileOp::inferReturnTypeComponents(
3471  MLIRContext *context, ::std::optional<Location> location,
3472  WhileOp::Adaptor adaptor,
3473  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3474  llvm::SmallVector<tosa::YieldOp> yieldOps;
3475  for (auto &block : adaptor.getBodyGraph())
3476  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3477  yieldOps.push_back(returnOp);
3478 
3479  // TOSA's while op must have a tosa.yield as its terminator. If one is not
3480  // found, this tosa.while is invalid.
3481  if (yieldOps.empty())
3482  return failure();
3483 
3484  // Get the initial type information from the operand types.
3485  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3486  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3487  for (auto operand : yieldOps.front().getOperands()) {
3488  resultKnowledge.push_back(
3489  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3490  }
3491 
3492  for (auto yieldOp : yieldOps) {
3493  if (resultKnowledge.size() != yieldOp.getNumOperands())
3494  return failure();
3495 
3496  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3497  int32_t index = it.index();
3498  if (auto meet = ValueKnowledge::meet(
3499  resultKnowledge[index],
3500  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3501  resultKnowledge[index] = meet;
3502  }
3503  }
3504  }
3505 
3506  for (const ValueKnowledge &result : resultKnowledge) {
3507  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3508  }
3509 
3510  return success();
3511 }
3512 
3513 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3514  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3515  return llvm::to_vector<4>(vt.getShape());
3516  return std::nullopt;
3517 }
3518 
3519 // The parse and print methods of IfOp mirror the SCF dialect implementation.
3520 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3521  // Create the regions for 'then'.
3522  result.regions.reserve(2);
3523  Region *thenRegion = result.addRegion();
3524  Region *elseRegion = result.addRegion();
3525 
3526  auto &builder = parser.getBuilder();
3527  OpAsmParser::UnresolvedOperand cond;
3528  // Create an i1 tensor type for the boolean condition.
3529  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3530  if (parser.parseOperand(cond) ||
3531  parser.resolveOperand(cond, i1Type, result.operands))
3532  return failure();
3533  // Parse optional results type list.
3534  if (parser.parseOptionalArrowTypeList(result.types))
3535  return failure();
3536  // Parse the 'then' region.
3537  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3538  return failure();
3539 
3540  // If we find an 'else' keyword then parse the 'else' region.
3541  if (!parser.parseOptionalKeyword("else")) {
3542  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3543  return failure();
3544  }
3545 
3546  // Parse the optional attribute list.
3547  if (parser.parseOptionalAttrDict(result.attributes))
3548  return failure();
3549  return success();
3550 }
3551 
3552 void IfOp::print(OpAsmPrinter &p) {
3553  bool printBlockTerminators = false;
3554 
3555  p << " " << getCondition();
3556  if (!getResults().empty()) {
3557  p << " -> (" << getResultTypes() << ")";
3558  // Print yield explicitly if the op defines values.
3559  printBlockTerminators = true;
3560  }
3561  p << ' ';
3562  p.printRegion(getThenGraph(),
3563  /*printEntryBlockArgs=*/false,
3564  /*printBlockTerminators=*/printBlockTerminators);
3565 
3566  // Print the 'else' region if it exists and has a block.
3567  auto &elseRegion = getElseGraph();
3568  if (!elseRegion.empty()) {
3569  p << " else ";
3570  p.printRegion(elseRegion,
3571  /*printEntryBlockArgs=*/false,
3572  /*printBlockTerminators=*/printBlockTerminators);
3573  }
3574 
3575  p.printOptionalAttrDict((*this)->getAttrs());
3576 }
3577 
3578 LogicalResult IfOp::verify() {
3579  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3580  "'then_graph' arguments", getInputList(),
3581  "'input_list'")
3582  .failed())
3583  return failure();
3584 
3585  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3586  "'else_graph' arguments", getInputList(),
3587  "'input_list'")
3588  .failed())
3589  return failure();
3590 
3591  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3592  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3593  "'then_graph' results", getOutputList(),
3594  "'output_list'")
3595  .failed())
3596  return failure();
3597 
3598  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3599  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3600  "'else_graph' results", getOutputList(),
3601  "'output_list'")
3602  .failed())
3603  return failure();
3604 
3605  auto condType = getCondition().getType();
3606  if (errorIfShapeNotSizeOne(*this, condType).failed())
3607  return emitOpError() << "'condition' must be a size 1 tensor, got "
3608  << condType;
3609 
3610  return success();
3611 }
3612 
3613 LogicalResult WhileOp::verify() {
3614  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3615  getOutputList(), "'output_list'")
3616  .failed())
3617  return failure();
3618 
3619  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3620  "'cond_graph' arguments", getInputList(),
3621  "'input_list'")
3622  .failed())
3623  return failure();
3624 
3625  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3626  "'body_graph' arguments", getInputList(),
3627  "'input_list'")
3628  .failed())
3629  return failure();
3630 
3631  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3632  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3633  "'body_graph' results", getInputList(),
3634  "'input_list'")
3635  .failed())
3636  return failure();
3637 
3638  // Condition block output must be a single element tensor with a single bool
3639  // value.
3640  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3641  if (condYield.getInputs().size() != 1)
3642  return emitOpError() << "require 'cond_graph' only have one result";
3643 
3644  auto condOutType = condYield.getInputs()[0].getType();
3645  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3646  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3647  << condOutType;
3648 
3649  if (!getElementTypeOrSelf(condOutType).isInteger(1))
3650  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3651  << condOutType;
3652 
3653  return success();
3654 }
3655 
3656 LogicalResult ReverseOp::verify() {
3657  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3658  /* outType = */ getOutput().getType())
3659  .failed())
3660  return failure();
3661  TensorType inputType = getInput1().getType();
3662  TensorType outputType = getOutput().getType();
3663  int32_t reverseAxis = getAxis();
3664 
3665  if (reverseAxis < 0)
3666  return emitOpError("expected non-negative reverse axis");
3667  if (inputType.hasRank()) {
3668  int64_t inputRank = inputType.getRank();
3669  // We allow for a special case where the input/output shape has rank 0 and
3670  // axis is also 0.
3671  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3672  return emitOpError("expect input tensor rank (")
3673  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3674  << ")";
3675  }
3676  if (outputType.hasRank()) {
3677  int64_t outputRank = outputType.getRank();
3678  if (inputType.hasRank() && outputRank != inputType.getRank())
3679  return emitOpError(
3680  "expect output tensor rank to be equal to input tensor rank");
3681  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3682  return emitOpError("expect output tensor rank (")
3683  << outputRank << ") to be larger than reverse axis ("
3684  << reverseAxis << ")";
3685  }
3686  return success();
3687 }
3688 
3689 LogicalResult tosa::SelectOp::verify() {
3690  // verify input2 and input3 have same element type as output
3691  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3692  /* outType = */ getOutput().getType())
3693  .failed() ||
3694  verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3695  /* outType = */ getOutput().getType())
3696  .failed()) {
3697  return failure();
3698  }
3699  // verify input1 has element type of bool
3700  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3701  if (!predicateType) {
3702  return emitOpError("expect shaped tensor for input1, got ")
3703  << getInput1().getType();
3704  }
3705  auto predicateElementType = predicateType.getElementType();
3706  if (!predicateElementType.isInteger(1)) {
3707  return emitOpError("expect element type of bool for input1, got ")
3708  << predicateElementType;
3709  }
3710 
3711  return success();
3712 }
3713 
3714 LogicalResult tosa::VariableOp::verify() {
3715  StringRef symName = getName();
3716  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3717  if (succeeded(varOp))
3718  return emitOpError("illegal to have multiple declaration of '")
3719  << symName << "'";
3720 
3721  return success();
3722 }
3723 
3724 LogicalResult tosa::VariableReadOp::verify() {
3725  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3726  .failed())
3727  return failure();
3728 
3729  return success();
3730 }
3731 
3732 LogicalResult tosa::VariableWriteOp::verify() {
3733  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3734  .failed())
3735  return failure();
3736 
3737  return success();
3738 }
3739 
3740 // The parse and print methods of WhileOp mirror the SCF dialect implementation.
3741 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3742  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3743  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3744  Region *cond = result.addRegion();
3745  Region *body = result.addRegion();
3746 
3747  OptionalParseResult listResult =
3748  parser.parseOptionalAssignmentList(regionArgs, operands);
3749  if (listResult.has_value() && failed(listResult.value()))
3750  return failure();
3751 
3752  FunctionType functionType;
3753  SMLoc typeLoc = parser.getCurrentLocation();
3754  if (failed(parser.parseColonType(functionType)))
3755  return failure();
3756 
3757  result.addTypes(functionType.getResults());
3758 
3759  if (functionType.getNumInputs() != operands.size()) {
3760  return parser.emitError(typeLoc)
3761  << "expected as many input types as operands "
3762  << "(expected " << operands.size() << " got "
3763  << functionType.getNumInputs() << ")";
3764  }
3765 
3766  // Resolve input operands.
3767  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3768  parser.getCurrentLocation(),
3769  result.operands)))
3770  return failure();
3771 
3772  // Propagate the types into the region arguments.
3773  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3774  regionArgs[i].type = functionType.getInput(i);
3775 
3776  return failure(parser.parseRegion(*cond, regionArgs) ||
3777  parser.parseKeyword("do") || parser.parseRegion(*body) ||
3778                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
3779 }
3780 
3781 static void printInitializationList(OpAsmPrinter &parser,
3782                                     Block::BlockArgListType blocksArgs,
3783  ValueRange initializers,
3784  StringRef prefix = "") {
3785  assert(blocksArgs.size() == initializers.size() &&
3786  "expected same length of arguments and initializers");
3787  if (initializers.empty())
3788  return;
3789 
3790  parser << prefix << '(';
3791  llvm::interleaveComma(
3792  llvm::zip(blocksArgs, initializers), parser,
3793  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3794  parser << ")";
3795 }
3796 
3797 void WhileOp::print(OpAsmPrinter &parser) {
3798  printInitializationList(parser, getCondGraph().front().getArguments(),
3799  getInputList(), " ");
3800  parser << " : ";
3801  parser.printFunctionalType(getInputList().getTypes(),
3802  getResults().getTypes());
3803  parser << ' ';
3804  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3805  parser << " do ";
3806  parser.printRegion(getBodyGraph());
3807  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3808 }
3809 
3810 // Create a rank-1 const tensor for the zero point of the source tensor.
3811 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3812  Location loc,
3813  Type srcElemType,
3814  int64_t zp) {
3815  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3816  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3817  if (llvm::isa<FloatType>(srcElemType)) {
3818  auto zpAttr = DenseElementsAttr::get(
3819  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3820  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3821  }
3822  if (llvm::isa<IntegerType>(srcElemType)) {
3823  auto zpAttr =
3824  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3825  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3826  }
3827  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3828  return std::nullopt;
3829 }
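// A minimal usage sketch of createZeroPointTensor (hypothetical call site,
// not part of this file), assuming an OpBuilder `builder` and Location `loc`:
//
//   std::optional<Value> zp = createZeroPointTensor(
//       builder, loc, builder.getIntegerType(8), /*zp=*/-128);
//
// materializes a tosa.const of type tensor<1xi8> holding -128; unsupported
// element types yield std::nullopt.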
3830 
3831 //===----------------------------------------------------------------------===//
3832 // TOSA Shape and Shape Operators Helper functions.
3833 //===----------------------------------------------------------------------===//
3834 
3835 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3836  return mlir::isa<tosa::shapeType>(t);
3837 }
3838 
3839 LogicalResult
3840 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3841                               int rank) {
3842  if (rank < 0)
3843  return emitError() << "invalid rank (must be >= 0): " << rank;
3844  return success();
3845 }
3846 
3847 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3848  for (auto v : op->getOperands()) {
3849  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3850  Operation *definingOp = v.getDefiningOp();
3851  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3852  return op->emitOpError("shape operand is not compile time resolvable");
3853  }
3854  }
3855  }
3856  return success();
3857 }
3858 
3859 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3860  for (auto type : op->getOperandTypes()) {
3861  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3862  return op->emitOpError("must have operands with tosa shape type");
3863  }
3864  }
3865  for (auto type : op->getResultTypes()) {
3866  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3867  return op->emitOpError("must have result with tosa shape type");
3868  }
3869  }
3870  return success();
3871 }
3872 
3873 LogicalResult
3874 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
3875  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
3876  failed(verifyTosaShapeOperator(op)))
3877  return failure();
3878 
3879  // Helper lambda that returns the rank of a tosa shape type.
3880  auto getRank = [](const Type type) {
3881  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
3882  };
3883  auto operandTypes = op->getOperandTypes();
3884  auto resultTypes = op->getResultTypes();
3885 
3886  auto rank = getRank(*op->getOperandTypes().begin());
3887  for (auto type : operandTypes) {
3888  if (getRank(type) != rank) {
3889  return op->emitOpError("operands don't have matching ranks");
3890  }
3891  }
3892  for (auto type : resultTypes) {
3893  if (getRank(type) != rank) {
3894  return op->emitOpError("result shape has different rank than operands");
3895  }
3896  }
3897  return success();
3898 }
3899 
3900 //===----------------------------------------------------------------------===//
3901 // TOSA Shape Operators verify functions.
3902 //===----------------------------------------------------------------------===//
3903 
3904 LogicalResult tosa::ConstShapeOp::verify() {
3905  // Check that the values attribute is rank 1.
3906  auto valuesRank = getValues().getType().getRank();
3907  if (valuesRank != 1)
3908  return emitOpError("expect elements in attribute values with rank 1");
3909  // Check that the number of elements in the values attribute equals the rank of the result shape.
3910  auto count = getValues().getNumElements();
3911  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
3912  if (!(count == rank || (count == 1 && rank == 0))) {
3913  return emitOpError("expect number of elements in attribute values (")
3914  << count << ") to be equal to the rank (" << rank
3915  << ") for the result shape type";
3916  }
3917  return success();
3918 }
3919 
3920 //===----------------------------------------------------------------------===//
3921 // TOSA Attribute Definitions.
3922 //===----------------------------------------------------------------------===//
3923 
3924 #define GET_ATTRDEF_CLASSES
3925 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
3926 
3927 //===----------------------------------------------------------------------===//
3928 // TOSA Type Definitions.
3929 //===----------------------------------------------------------------------===//
3930 #define GET_TYPEDEF_CLASSES
3931 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
3932 
3933 //===----------------------------------------------------------------------===//
3934 // TOSA Operator Definitions.
3935 //===----------------------------------------------------------------------===//
3936 
3937 #define GET_OP_CLASSES
3938 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"