TosaOps.cpp
//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://www.mlplatform.org/tosa/tosa_spec.html
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
  return {&getBodyGraph()};
}

//===----------------------------------------------------------------------===//
// TOSA variable operator support.
//===----------------------------------------------------------------------===//

SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));
}
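
// Illustrative note (not part of the upstream file): convertToMlirShape maps
// TOSA's -1 "unknown dimension" convention onto MLIR's dynamic-dim sentinel.
// A minimal usage sketch:
//   SmallVector<int64_t> s = convertToMlirShape({2, -1, 4});
//   // s == {2, ShapedType::kDynamic, 4}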

// Returns the type of a variable op.
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
  Type elementType = variableOp.getType();
  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
  return RankedTensorType::get(shape, elementType);
}

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant, which supports all attributes.
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

namespace {

ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);
    ArrayRef<int64_t> shape = shapedType.getShape();
    Builder builder(parser.getContext());
    varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
    return success();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

} // namespace

// Parses the optional initial value or type for a tosa variable.
// With initial value:
//   tosa.variable @name = dense<0.0> : tensor<1x8xf32>
//
// Without initial value:
//   tosa.variable @name : tensor<1x8xf32>
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
    Attribute &initialValueAttr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(initialValueAttr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                    typeAttr);
    }
    return parser.emitError(parser.getCurrentLocation())
           << "expected Typed attr";
  }

  initialValueAttr = nullptr;
  Type parsedType;
  if (failed(parser.parseColonType(parsedType))) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";
  }
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}

void mlir::tosa::printVariableOpTypeOrInitialValue(
    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
    TypeAttr typeAttr, Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    auto shape =
        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
    auto tensorTypeAttr = TypeAttr::get(tensorType);
    p << ": ";
    p.printAttribute(tensorTypeAttr);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (initialValueAttr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(initialValueAttr);
  }
}

//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//

std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}
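
// Illustrative note (not part of the upstream file): idivCheck only succeeds
// for exact divisions, which is how the ERROR_IF shape checks below detect
// non-integral output sizes. For example:
//   idivCheck(12, 4); // returns 3
//   idivCheck(12, 5); // returns std::nullopt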

Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
  auto srcType = getElementTypeOrSelf(type);
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
  return srcType;
}

Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
  return getStorageElementTypeOrSelf(value.getType());
}

static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
                                                  Value valZp, StringRef name) {
  Type eType = getStorageElementTypeOrSelf(val.getType());
  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());

  bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  bool sameBitWidth =
      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());

  if (!bothInts || !sameBitWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }
  return success();
}

// Create a pad-const tensor with value `val` of the required data type.
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
                                       Value src, int32_t val) {
  const auto srcType = getElementTypeOrSelf(src);
  const auto srcElemType = getStorageElementTypeOrSelf(src);
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
          ? DenseElementsAttr::get(padConstEType,
                                   builder.getFloatAttr(srcElemType, val))
          : DenseElementsAttr::get(padConstEType,
                                   builder.getIntegerAttr(srcElemType, val))};
  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
}
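
// Illustrative usage sketch (not part of the upstream file, assuming the
// signature above): a lowering pattern with an OpBuilder `rewriter` and a
// tensor value `input` could materialize a zero pad constant as
//   Value padConst = createPadConstTensor(rewriter, loc, input, /*val=*/0);
// which yields a rank-1, single-element tosa.const matching input's storage
// element type.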

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // for now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
          "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
            "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
          "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}

LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}

template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  return success();
}
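
// Illustrative summary (not part of the upstream file) of the accumulator
// rules enforced above, keyed by input storage element type:
//   i8 -> i32, i16 -> i48, f8 (E5M2/E4M3) -> f16,
//   f16 -> f16 or f32, bf16 -> f32, f32 -> f32.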

//===----------------------------------------------------------------------===//
// ERROR_IF functions.
// ERROR_IF is a predicate that must set an error if the condition holds.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)
        return success();

      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
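      //
      // Illustrative worked example (not part of the upstream file): for a
      // conv2d height dimension with I = 15, pad = {1, 1}, K = 3, dilation = 1
      // and stride = 2, the check requires O == (15 - 1 + 1 + 1 - 2) / 2 + 1,
      // i.e. O == 8; with stride = 4 the division 14 / 4 is inexact, so
      // idivCheck fails and the op is rejected.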

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
          stride);
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
               << stride;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

      return success();
    };

    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic dim
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}

// Verify that the two given types have the same element type and compatible
// shapes.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
                                                StringRef name1, Type type2,
                                                StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return failure();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

  if (failed(verifyCompatibleShape(type1, type2)))
    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  return success();
}

// Verify that the two given tensor lists have the same length and that their
// elements pairwise match in type and shape.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
                                                StringRef name1,
                                                ValueRange list2,
                                                StringRef name2) {
  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :
       llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
    if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
      return failure();
  }

  return success();
}

static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
  ShapeAdaptor shapeAdaptor(type);
  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
    return success();

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}

// Returns the first declaration point prior to this operation or failure if
// not found.
static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
                                                    StringRef symName) {
  ModuleOp module = op->getParentOfType<ModuleOp>();
  tosa::VariableOp varOp = nullptr;

  // TODO: Adopt the SymbolTable trait for Variable ops.
  // Currently, the variable's definition point is searched via walk(),
  // starting from the top-level ModuleOp and stopping at the point of use. Once
  // TOSA control flow and variable extensions reach the complete state, we may
  // leverage MLIR's Symbol Table functionality to look up the symbol and
  // enhance the search to a TOSA-specific graph traversal over the IR
  // structure.
  module.walk([&](Operation *tempOp) {
    // Reach this op itself.
    if (tempOp == op) {
      return WalkResult::interrupt();
    }

    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
      if (symName == tosaOp.getName()) {
        varOp = tosaOp;
        return WalkResult::interrupt();
      }
    }

    return WalkResult::advance();
  });

  if (varOp)
    return varOp;

  return failure();
}

template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
  StringRef symName = op.getName();
  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
  if (failed(varOp))
    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";

  // Verify type and shape
  auto variableType = getVariableType(varOp.value());
  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                 "the input tensor")
          .failed())
    return failure();

  return success();
}

// Verify that inType and outType have the same element types.
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
  if (!inputType) {
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    return failure();
  }
  if (!outputType) {
    op.emitOpError("expect shaped tensor for output, got ") << outType;
    return failure();
  }
  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    // Only check if both element types are int/index/float/UniformQuantized;
    // e.g., it is unclear how to check quant::QuantizedType. This case appears
    // in test_conv2d_q_grouped_convolution in tfl-to-tosa-pipeline.mlir.
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
    return failure();
  }
  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure output is of 32-bit integer
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  // Ensure axis is within the tensor rank
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}

template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  const llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  const llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  // Padding must be less than kernel size to avoid a divide-by-zero
  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  if (failed(verifyPoolingOp(*this)))
    return failure();

  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    // if input datatype is integer, check that the min_val/max_val attributes
    // are integer attributes, and that their type is the same as the input's
    // datatype
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // otherwise, input datatype is float, check that the min_val/max_val
    // attributes share the same type and that their type is the same as the
    // input's datatype
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation,
                                     TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, DenseI64ArrayAttr outpad,
                              DenseI64ArrayAttr stride, TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// The tosa.matmul op is also intended to be generated where a fully_connected
/// op must be constructed but the weight is not a constant. In this case,
/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  auto zps = createZPsAsConst(builder, a, b);
  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};
  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
    auto eType = getStorageElementTypeOrSelf(a.getType());
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();

    finalOutputType = outputShapedType.clone(accElementType);
  }
  result.addTypes(finalOutputType);
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  const Location loc{result.location};
  int64_t inputZp{0};
  int64_t outputZp{0};

  if (auto quantAttr =
          buildUnaryOpQuantizationAttr(builder, input, outputType)) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> inputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), inputZp);
  if (!inputZpOp) {
    (void)emitError(
        loc,
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");
  }
  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, outputType, outputZp);
  if (!outputZpOp) {
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");
  }

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
  } else {
    // Failed to create one or more zero points above: just add input as
    // operand; this will trigger an error when building the op because of the
    // missing zero points.
    result.addOperands({input});
  }
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  result.types.push_back(outputType);
}

/// This builder is called on the single-input negate operator to construct
/// input and output zero points based on their types.
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value input) {
  const Location loc{result.location};
  int64_t input1Zp{0};
  int64_t outputZp{0};
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr) {
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> input1ZpOp =
      createZeroPointTensor(builder, loc, input.getType(), input1Zp);
  if (!input1ZpOp) {
    (void)emitError(
        loc, "Failed to create input1 zero point for quantized NEGATE op");
  }

  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), outputZp);
  if (!outputZpOp) {
    (void)emitError(
        loc, "Failed to create output zero point for quantized NEGATE op");
  }

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
  } else {
    // Failed to create one or more zero points above: just add input as
    // operand; this will trigger an error when building the op because of the
    // missing zero points.
    result.addOperands({input});
  }

  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  const Location loc{result.location};
  int32_t zp{0};
  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr) {
    zp = static_cast<int32_t>(quantAttr.getInputZp());
  }
  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);
}

static void buildVariableOp(OpBuilder &builder, OperationState &result,
                            StringRef name, Type variableType,
                            Attribute initialValue) {
  const Location loc{result.location};
  auto nameAttr = builder.getStringAttr(name);

  auto shapedType = dyn_cast<ShapedType>(variableType);
  if (!shapedType) {
    (void)emitError(loc, "variable type must be a shaped type");
    return;
  }
  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");
    return;
  }

  auto elementType = shapedType.getElementType();
  auto elementTypeAttr = TypeAttr::get(elementType);
  ArrayRef<int64_t> shape = shapedType.getShape();
  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));

  result.addAttribute("name", nameAttr);
  result.addAttribute("var_shape", varShapeAttr);
  result.addAttribute("type", elementTypeAttr);
  result.addAttribute("initial_value", initialValue);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for
      // invalid operands and for ranked tensors.
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
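
// Illustrative note (not part of the upstream file): resolveBroadcastShape
// performs standard numpy-style broadcasting over the operand shapes. For
// example, operands of shape tensor<1x3xf32> and tensor<2x1xf32> resolve to an
// output shape of {2, 3}, whereas {2, 3} vs. {4, 3} returns failure().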

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

  return success();
}

static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
                                           const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
  if (!isPowerOfTwo)
    return op->emitOpError("expected ")
           << dimName << " to be a power of two, got " << dimSize;

  return success();
}
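
// Illustrative note (not part of the upstream file): the expression
// (dimSize & (dimSize - 1)) == 0 relies on a power of two having exactly one
// set bit, so subtracting one clears that bit and sets everything below it.
// For example, 8 & 7 == 0b1000 & 0b0111 == 0, while 6 & 5 == 0b110 & 0b101 ==
// 0b100 != 0; the dimSize > 0 guard excludes zero.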

LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  // Batch and height input/output dimensions should match
  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  // Output width dimension expected to be input_width / 2 + 1
  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;

  return success();
}

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
  return success();
}

LogicalResult tosa::FFT2dOp::verify() {
  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)
    return success();

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;
  };

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
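
// Illustrative note (not part of the upstream file): concatenating
// tensor<2x3xf32> and tensor<2x5xf32> along axis = 1 infers an output shape of
// {2, 8}; if any operand's axis dimension is dynamic, the inferred axis
// dimension becomes ShapedType::kDynamic instead.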

LogicalResult tosa::ConcatOp::verify() {
  // check that each input has same element type as output
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();
  ShapeAdaptor firstRankedInputShape = nullptr;
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      // Check axis is in expected range
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check non-axis dims match
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }

    // ERROR_IF(axis_sum != shape[axis]);
    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      const ShapeAdaptor inputShape(input.getType());
      if (inputShape.isDynamicDim(axis)) {
        // make axisSum negative to indicate invalid value
        axisSum = -1;
        break;
      }
      axisSum += inputShape.getDimSize(axis);
    }
    const ShapeAdaptor outputShape(outType);
    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
  }

  return success();
}
1525 
1526 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1527  MLIRContext *context, ::std::optional<Location> location,
1528  ValueShapeRange operands, DictionaryAttr attributes,
1529  OpaqueProperties properties, RegionRange regions,
1530  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1531  auto elementType = IntegerType::get(context, /*width=*/1);
1532 
1533  llvm::SmallVector<int64_t> outShape;
1534  if (resolveBroadcastShape(operands, outShape).failed()) {
1535  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1536  return success();
1537  }
1538 
1539  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1540  return success();
1541 }
1542 
1543 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544  if (l.size() != r.size() || l.size() != 1)
1545  return false;
1546  return succeeded(verifyCompatibleShape(l[0], r[0]));
1547 }
1548 
1549 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1550  MLIRContext *context, ::std::optional<Location> location,
1551  MatMulOp::Adaptor adaptor,
1552  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1553  ShapeAdaptor lhsShape(adaptor.getA().getType());
1554  ShapeAdaptor rhsShape(adaptor.getB().getType());
1555 
1556  // All shapes are dynamic.
1557  SmallVector<int64_t> outShape;
1558  outShape.resize(3, ShapedType::kDynamic);
1559 
1560  if (lhsShape.hasRank()) {
1561  outShape[0] = lhsShape.getDimSize(0);
1562  outShape[1] = lhsShape.getDimSize(1);
1563  }
1564 
1565  if (rhsShape.hasRank()) {
1566  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1567  : outShape[0];
1568  outShape[2] = rhsShape.getDimSize(2);
1569  }
1570 
1571  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1572  return success();
1573 }
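// Worked example (hypothetical shapes, for illustration): with
// a : tensor<1x3x4xf32> and b : tensor<1x4x5xf32>, the inference above takes
// N and H from a and W from b, giving the output shape [1, 3, 5]; the shared
// C dimension (4) does not appear in the result shape.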
1574 
1575 LogicalResult MatMulOp::verify() {
1576  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1577  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1578 
1579  // Must be shaped tensor types
1580  if (!aType)
1581  return emitOpError("expect a shaped tensor for input a, got ")
1582  << getA().getType();
1583 
1584  if (!bType)
1585  return emitOpError("expect a shaped tensor for input b, got ")
1586  << getB().getType();
1587 
1588  auto aElementType = aType.getElementType();
1589  auto bElementType = bType.getElementType();
1590 
1591  auto aQuantizedEType =
1592  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1593  auto bQuantizedEType =
1594  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1595 
1596  if (aQuantizedEType || bQuantizedEType) {
1597  if (!aQuantizedEType || !bQuantizedEType) {
1598  return emitOpError("expect operands to be both quantized or both not "
1599  "quantized, got ")
1600  << aElementType << " and " << bElementType;
1601  }
1602  // both a and b have quantized element types
1603  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1604  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1605  if (aQuantWidth != bQuantWidth) {
1606  return emitOpError("expect quantized operands to have same widths, got ")
1607  << aQuantWidth << " and " << bQuantWidth;
1608  }
1609  } else {
1610  // non-quantized element types
1611  if (aElementType != bElementType) {
1612  return emitOpError("expect same element type for inputs a and b, got ")
1613  << aElementType << " and " << bElementType;
1614  }
1615  }
1616 
1617  // check a_zp and b_zp
1618  auto aEType = getStorageElementTypeOrSelf(aType);
1619  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1620  if (aEType != aZpEType) {
1621  return emitOpError("expect input a and a_zp have the same "
1622  "element type, got ")
1623  << aEType << " and " << aZpEType;
1624  }
1625 
1626  auto bEType = getStorageElementTypeOrSelf(bType);
1627  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1628  if (bEType != bZpEType) {
1629  return emitOpError("expect input b and b_zp have the same "
1630  "element type, got ")
1631  << bEType << " and " << bZpEType;
1632  }
1633 
1634  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1635  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1636  return failure();
1637 
1638  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1639  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1640  return failure();
1641 
1642  return success();
1643 }
1644 
1645 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1646  MLIRContext *context, ::std::optional<Location> location,
1647  PadOp::Adaptor adaptor,
1648  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1649  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1650  auto paddingRank =
1651  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1652  SmallVector<int64_t> outputShape;
1653 
1654  // If the input rank is unknown, we can infer the output rank using the
1655  // padding shape's rank divided by 2.
1656  if (!inputShape.hasRank()) {
1657  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1658  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1659  return success();
1660  }
1661 
1662  SmallVector<int64_t> paddingValues;
 1663  // If the padding values are not constant, all dimensions must be dynamic.
1664  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1665  paddingValues)) {
1666  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1667  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1668  return success();
1669  }
1670 
1671  outputShape.reserve(inputShape.getRank());
1672  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1673  if (inputShape.isDynamicDim(i)) {
1674  outputShape.push_back(ShapedType::kDynamic);
1675  continue;
1676  }
1677  auto padFront = paddingValues[i * 2];
1678  auto padBack = paddingValues[i * 2 + 1];
1679  if (padFront < 0 || padBack < 0) {
1680  // if either padding for dim i is -1, output dim is unknown
1681  outputShape.push_back(ShapedType::kDynamic);
1682  continue;
1683  }
1684 
1685  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1686  }
1687 
1688  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1689  return success();
1690 }
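// Worked example (hypothetical values, for illustration): padding
// tensor<2x3xf32> with constant padding values [1, 1, 0, 2] (front/back per
// dimension) yields dims 2 + 1 + 1 = 4 and 3 + 0 + 2 = 5, i.e. shape [4, 5];
// a negative padding entry or a dynamic input dim produces a dynamic dim.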
1691 
1692 LogicalResult tosa::PadOp::verify() {
1693  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1694  /* outType = */ getOutput().getType())
1695  .failed()) {
1696  return failure();
1697  }
1698 
1699  if (auto padConst = getPadConst()) {
1700  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1701  /* outType = */ getOutput().getType())
1702  .failed()) {
1703  return failure();
1704  }
1705  }
1706 
1707  RankedTensorType inputType =
1708  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1709  RankedTensorType outputType =
1710  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1711  if (!inputType || !outputType)
1712  return success();
1713 
1714  auto inputRank = inputType.getRank();
1715  auto outputRank = outputType.getRank();
1716  if (inputRank != outputRank)
1717  return emitOpError() << "expect same input and output tensor rank, but got "
1718  << "inputRank: " << inputRank
1719  << ", outputRank: " << outputRank;
1720 
1721  DenseIntElementsAttr paddingAttr;
1722  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1723  return failure();
1724  }
1725 
1726  auto paddingValues = paddingAttr.getValues<APInt>();
1727  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1728  return emitOpError() << "padding tensor must have " << inputRank
1729  << " * 2 = " << inputRank * 2 << " elements, but got "
1730  << paddingValues.size();
1731 
1732  auto inputShape = inputType.getShape();
1733  auto outputShape = outputType.getShape();
1734 
1735  for (int64_t i = 0; i < inputRank; ++i) {
1736  int64_t padStart = paddingValues[i * 2].getSExtValue();
1737  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1738 
1739  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1740  return emitOpError()
1741  << "invalid padding values at dimension " << i
1742  << ": values must be non-negative or -1 for dynamic padding, got ["
1743  << padStart << ", " << padEnd << "]";
1744  }
1745 
1746  // Skip shape verification for dynamic input/output
1747  if (inputShape[i] == ShapedType::kDynamic ||
1748  outputShape[i] == ShapedType::kDynamic)
1749  continue;
1750 
1751  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1752  return emitOpError() << "mismatch in output shape at dimension " << i
1753  << ": expected " << inputShape[i] << " + "
1754  << padStart << " + " << padEnd << " = "
1755  << (inputShape[i] + padStart + padEnd)
1756  << ", but got " << outputShape[i];
1757  }
1758  }
1759 
1760  return success();
1761 }
1762 
1763 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1764  MLIRContext *context, ::std::optional<Location> location,
1765  SliceOp::Adaptor adaptor,
1766  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1767 
1768  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1769  SmallVector<int64_t> start;
1770  SmallVector<int64_t> size;
1771 
1772  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1773  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1774  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1775  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1776  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1777  return success();
1778  }
1779 
1780  // if size[i] is -1, all remaining elements in dimension i are included
1781  // in the slice, similar to TF.
1782  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1783  // initialize outputShape to all unknown
1784  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1785  if (inputShape.hasRank()) {
1786  for (size_t i = 0; i < size.size(); i++) {
1787  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1788  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1789  start[i] < inputShape.getDimSize(i))) {
1790  // size[i] is not 0 and not < -1, and start[i] is in valid range
1791  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1792  // input shape has unknown dim[i] - only valid if size[i] > 0
1793  if (size[i] > 0) {
1794  outputShape[i] = size[i];
1795  }
1796  } else {
1797  // input shape has known dim[i]
1798  if (size[i] == -1) {
1799  outputShape[i] = inputShape.getDimSize(i) - start[i];
1800  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1801  // start[i] + size[i] is within bound of input shape's dim[i]
1802  outputShape[i] = size[i];
1803  }
1804  }
1805  }
1806  }
1807  } else {
1808  outputShape = convertToMlirShape(size);
1809  }
1810  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1811  return success();
1812 }
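// Worked example (hypothetical values, for illustration): slicing
// tensor<10x?xf32> with start = [2, 0] and size = [-1, 4] gives dim 0 =
// 10 - 2 = 8 (size -1 takes the remainder) and dim 1 = 4 (positive size on a
// dynamic input dim), so the inferred shape is [8, 4].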
1813 
1814 LogicalResult tosa::SliceOp::verify() {
1815  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1816  /* outType = */ getOutput().getType())
1817  .failed())
1818  return failure();
1819 
1820  const ShapeAdaptor inputShape(getInput1().getType());
1821  if (inputShape.hasRank()) {
1822  const auto inputRank = inputShape.getRank();
1823  const ShapeAdaptor outputShape(getOutput().getType());
1824  if (outputShape.hasRank() && inputRank != outputShape.getRank())
1825  return emitOpError(
1826  "expect input1 and output to have the same ranks, got ")
1827  << inputRank << " and " << outputShape.getRank();
1828 
1829  const auto startShapeRank =
1830  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1831  if (inputRank != startShapeRank)
1832  return emitOpError("length of start is not equal to rank of input shape");
1833 
1834  const auto sizeShapeRank =
1835  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1836  if (inputRank != sizeShapeRank)
1837  return emitOpError("length of size is not equal to rank of input shape");
1838  }
1839 
1840  return success();
1841 }
1842 
1843 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1844  MLIRContext *context, ::std::optional<Location> location,
1845  ValueShapeRange operands, DictionaryAttr attributes,
1846  OpaqueProperties properties, RegionRange regions,
1847  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 1848  // The mul op's output shape depends only on input1 and input2, not on shift.
1849  ValueShapeRange twoInputs = operands.drop_back();
1850  llvm::SmallVector<int64_t> outShape;
1851  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1852  inferredReturnShapes.push_back(ShapedTypeComponents());
1853  } else {
1854  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1855  }
1856  return success();
1857 }
1858 
1859 LogicalResult tosa::MulOp::verify() {
1860  const Value output = getOutput();
1861  auto resElemType = getElementTypeOrSelf(output);
1862 
1863  // Verify if the element type among operands and result match tosa
1864  // specification.
1865  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1866  IntegerType lhsIntType =
1867  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1868  IntegerType rhsIntType =
1869  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1870  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1871  return emitOpError("requires the same element type for all operands");
1872 
 1873  // Though the spec requires the element type of the result to be i32, a more
 1874  // relaxed check is provided at the dialect level to ease interoperation
 1875  // with other dialects.
1876  if (lhsIntType.getWidth() > resIntType.getWidth())
1877  return emitOpError("invalid data type size for operands or result");
1878 
1879  } else {
 1880  // For other supported types, the spec requires the same element type for
 1881  // all operands (excluding the `shift` operand) and results.
1882  for (int i = 0; i < 2; ++i) {
1883  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1884  return emitOpError(
1885  "requires the same element type for all operands and results");
1886  }
1887 
1888  // verify shift has value 0 for non-integer types
1889  ElementsAttr shift_elem;
1890  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1891  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1892  if (shift != 0) {
1893  return emitOpError() << "require shift to be 0 for float type";
1894  }
1895  }
1896  }
1897 
 1898  // Verify that the op has the same rank for all main operands (excluding extra
 1899  // operands such as mul's shift, which is the only difference from the built-in
 1900  // `SameOperandsAndResultRank` trait) and result types, if known.
1901  TypeRange operandTypes = getOperandTypes();
1902  ShapedType aType = cast<ShapedType>(operandTypes[0]);
1903  ShapedType bType = cast<ShapedType>(operandTypes[1]);
1904 
1905  const bool aHasRank = aType.hasRank();
1906  const bool bHasRank = bType.hasRank();
1907  if (aHasRank && bHasRank) {
1908  const int64_t aRank = aType.getRank();
1909  const int64_t bRank = bType.getRank();
1910  if (aRank != bRank)
1911  return emitOpError("a and b operands don't have matching ranks, got ")
1912  << aRank << " and " << bRank;
1913 
1914  // check for broadcast compatible shapes
 1915  SmallVector<int64_t> resultShape;
 1916  if (!mlir::OpTrait::util::getBroadcastedShape(
 1917  aType.getShape(), bType.getShape(), resultShape))
1918  return emitOpError("a and b operands don't have broadcast-compatible "
1919  "shapes, got ")
1920  << aType << " and " << bType;
1921  }
1922 
1923  ShapedType resultType = cast<ShapedType>(output.getType());
1924  if (!resultType.hasRank())
1925  return success();
1926 
1927  const int64_t resultRank = resultType.getRank();
1928  if (aHasRank && resultRank != aType.getRank())
1929  return emitOpError("result type has different rank than a, got ")
1930  << resultRank << " vs " << aType.getRank();
1931  if (bHasRank && resultRank != bType.getRank())
1932  return emitOpError("result type has different rank than b, got ")
1933  << resultRank << " vs " << bType.getRank();
1934 
1935  return success();
1936 }
1937 
1938 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1939  MLIRContext *context, ::std::optional<Location> location,
1940  TableOp::Adaptor adaptor,
1941  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1942  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1943 
1944  if (!inputShape.hasRank()) {
1945  inferredReturnShapes.push_back(ShapedTypeComponents());
1946  return success();
1947  }
1948 
1949  inferredReturnShapes.resize(1);
1950  inputShape.getDims(inferredReturnShapes[0]);
1951  return success();
1952 }
1953 
1954 LogicalResult tosa::TableOp::verify() {
1955  TensorType inputType = getInput1().getType();
1956  TensorType outputType = getOutput().getType();
1957 
1958  if (inputType.hasRank() && outputType.hasRank() &&
1959  inputType.getRank() != outputType.getRank())
1960  return emitOpError()
1961  << "expected input tensor rank to equal result tensor rank";
1962 
1963  auto inputDims = inputType.getShape();
1964  auto outputDims = outputType.getShape();
1965  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1966  int64_t dim = it.index();
1967  auto [inputDim, outputDim] = it.value();
1968  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1969  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1970  << " doesn't match dim(input, " << dim
1971  << ") = " << inputDim;
1972  }
1973  }
1974  return success();
1975 }
1976 
1977 LogicalResult
1978 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1979  // Multiples must be constants.
1980  DenseIntElementsAttr multiplesAttr;
1981  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1982  return failure();
1983  multiples = llvm::to_vector(
1984  llvm::map_range(multiplesAttr.getValues<APInt>(),
1985  [](const APInt &val) { return val.getSExtValue(); }));
1986  return success();
1987 }
1988 
1989 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1990  MLIRContext *context, ::std::optional<Location> location,
1991  TileOp::Adaptor adaptor,
1992  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1993  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1994  SmallVector<int64_t> multiples;
1995  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1996  multiples)) {
1997  auto rank =
1998  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1999  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2000  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2001  return success();
2002  } else {
2003  multiples = convertToMlirShape(multiples);
2004  }
2005 
2006  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2007  SmallVector<int64_t> outputShape;
2008  if (!inputShape.hasRank()) {
2009  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2010  inferredReturnShapes.push_back(
2011  ShapedTypeComponents(outputShape, inputType));
2012  return success();
2013  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2014  return failure();
2015 
 2016  // Any non-dynamic dimension can be multiplied to a known size.
2017  outputShape.reserve(multiples.size());
2018  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2019  if (multiples[i] == ShapedType::kDynamic) {
2020  outputShape.push_back(ShapedType::kDynamic);
2021  } else {
2022  int64_t dim = inputShape.getDimSize(i);
2023  if (dim != ShapedType::kDynamic)
2024  dim *= multiples[i];
2025  outputShape.push_back(dim);
2026  }
2027  }
2028 
2029  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2030  return success();
2031 }
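// Worked example (hypothetical values, for illustration): tiling
// tensor<2x3xf32> with multiples = [3, 1] multiplies each static dim, giving
// tensor<6x3xf32>; a dynamic input dim or a dynamic multiple stays dynamic.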
2032 
2033 LogicalResult tosa::TileOp::verify() {
 2034  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2035  /* outType = */ getOutput().getType())
2036  .failed()) {
2037  return failure();
2038  }
2039  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2040  ShapedType outputType = llvm::cast<ShapedType>(getType());
2041 
2042  shapeType multiplesType =
2043  llvm::cast<tosa::shapeType>(getMultiples().getType());
2044 
2045  auto multiplesRank = multiplesType.getRank();
2046 
2047  if (inputType.hasRank()) {
2048  if (inputType.getRank() != multiplesRank)
2049  return emitOpError("expect 'multiples' to have rank ")
2050  << inputType.getRank() << " but got " << multiplesRank << ".";
2051  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2052  return emitOpError("expect same input and output tensor rank.");
2053  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2054  return emitOpError("expect 'multiples' array to have length ")
2055  << outputType.getRank() << " but got " << multiplesRank << ".";
2056 
2057  SmallVector<int64_t> multiples;
2058  if (getConstantMultiples(multiples).succeeded() &&
2059  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2060  return emitOpError(
2061  "expect element of 'multiples' to be positive integer or -1.");
2062 
2063  return success();
2064 }
2065 
2066 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2067  if (l.size() != r.size() || l.size() != 1)
2068  return false;
2069  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2070 }
2071 
2072 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2073  MLIRContext *context, ::std::optional<Location> location,
2074  ReshapeOp::Adaptor adaptor,
2075  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2076  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2077  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2078  llvm::SmallVector<int64_t> newShapeValue;
2079  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2080  newShapeValue)) {
2081  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2082  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2083  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2084  return success();
2085  } else {
2086  newShapeValue = convertToMlirShape(newShapeValue);
2087  }
2088 
 2089  // We cannot infer from the total number of elements, so we must take the
 2090  // constant shape operand as exact.
2091  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2092  inferredReturnShapes.push_back(
2093  ShapedTypeComponents(newShapeValue, inputType));
2094  return success();
2095  }
2096 
2097  // Determine the number of elements covered by the slice of all static
2098  // dimensions. This allows us to infer the length of the remaining dynamic
2099  // dimension.
2100  int64_t numElements = inputShape.getNumElements();
2101  int64_t staticMul = 1;
2102  for (auto val : newShapeValue) {
2103  if (!ShapedType::isDynamic(val)) {
2104  staticMul *= val;
2105  }
2106  }
2107 
2108  // Determine the length of the dynamic dimension.
2109  for (auto &val : newShapeValue) {
2110  if (ShapedType::isDynamic(val))
2111  val = numElements / staticMul;
2112  }
2113 
2114  inferredReturnShapes.push_back(
2115  ShapedTypeComponents(newShapeValue, inputType));
2116  return success();
2117 }
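// Worked example (hypothetical values, for illustration): reshaping a static
// tensor<2x3x4xf32> (24 elements) with shape = [4, -1] computes the static
// product 4 and resolves the dynamic dim as 24 / 4 = 6, inferring
// tensor<4x6xf32>.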
2118 
2119 llvm::LogicalResult tosa::ReshapeOp::verify() {
2120  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2121  /* outType = */ getOutput().getType())
2122  .failed()) {
2123  return failure();
2124  }
2125  TensorType inputType = getInput1().getType();
2126 
2127  SmallVector<int64_t> shapeValues;
2128  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2129  // skip following checks if shape is not constant
2130  return mlir::success();
2131  }
2132 
2133  int missingDims = llvm::count(shapeValues, -1);
2134  if (missingDims > 1)
2135  return emitOpError() << "expected at most one target dimension to be -1";
2136 
2137  const auto outputType = dyn_cast<RankedTensorType>(getType());
2138  if (!outputType)
2139  return success();
2140 
2141  if ((int64_t)shapeValues.size() != outputType.getRank())
2142  return emitOpError() << "new shape does not match result rank";
2143 
2144  for (auto [newShapeDim, outputShapeDim] :
2145  zip(shapeValues, outputType.getShape())) {
2146  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2147  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2148  return emitOpError() << "new shape is inconsistent with result shape";
2149 
2150  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2151  return emitOpError() << "new shape has invalid tensor dimension size "
2152  << newShapeDim;
2153  }
2154 
2155  if (inputType.hasStaticShape()) {
2156  int64_t inputElementsNum = inputType.getNumElements();
2157  if (outputType.hasStaticShape()) {
2158  int64_t outputElementsNum = outputType.getNumElements();
2159  if (inputElementsNum != outputElementsNum) {
2160  return emitOpError() << "cannot reshape " << inputElementsNum
2161  << " elements into " << outputElementsNum;
2162  }
2163  }
2164 
2165  int64_t newShapeElementsNum = std::accumulate(
2166  shapeValues.begin(), shapeValues.end(), 1LL,
2167  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2168  bool isStaticNewShape =
2169  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2170  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2171  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2172  return emitOpError() << "cannot reshape " << inputElementsNum
2173  << " elements into " << newShapeElementsNum;
2174  }
2175  }
2176 
2177  return mlir::success();
2178 }
2179 
 2180 // Return failure if val is not a constant.
 2181 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
 2182 // otherwise set zp to val's constant value.
2183 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2184  ElementsAttr zpAttr;
2185  if (!matchPattern(val, m_Constant(&zpAttr))) {
2186  return failure();
2187  }
2188 
2189  Type zpElemType = zpAttr.getElementType();
2190 
2191  if (llvm::isa<FloatType>(zpElemType)) {
2192  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2193  return 0;
2194  }
2195  // return non-zero value to trigger error check
2196  return -1;
2197  }
2198 
2199  if (llvm::isa<IntegerType>(zpElemType)) {
2200  if (signExtend)
2201  return zpAttr.getValues<APInt>()[0].getSExtValue();
2202  else
2203  return zpAttr.getValues<APInt>()[0].getZExtValue();
2204  }
2205 
2206  // return non-zero value to trigger error check
2207  return -1;
2208 }
2209 
2210 template <typename T>
2211 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2212  const std::string &operand) {
2213  Type zpElemType = getElementTypeOrSelf(val);
2214 
2215  if (!zpElemType.isInteger(8) && zp != 0) {
2216  // convert operand to lower case for error message
2217  std::string lower = operand;
2218  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2219  return op.emitOpError()
2220  << lower << " zero point must be zero for non-int8 integer types";
2221  }
2222 
2223  return success();
2224 }
2225 
2226 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2227  const int64_t &zp,
2228  const std::string &operand) {
2229  bool isInputZp = (operand == "Input");
2230 
2231  bool tensorUnsigned =
2232  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2233  StringRef tensorName = isInputZp ? "input" : "output";
2234 
2235  Type zpElemType = getElementTypeOrSelf(zpVal);
2236 
2237  if (zp != 0) {
2238  if (!zpElemType.isInteger(8) &&
2239  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2240  return op.emitOpError()
2241  << "expect " << tensorName << "_zp of 0, got " << zp;
2242  }
2243  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2244  return op.emitOpError() << "expect " << tensorName
2245  << "_zp of 0 or 32768 for unsigned int16 "
2246  << tensorName << ", got " << zp;
2247  }
2248  }
2249 
2250  return success();
2251 }
2252 
2253 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2254  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2255  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2256  } \
2257  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2258  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2259  }
2260 
2261 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2262 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2263 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2264 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2265 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2266 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2267 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2268 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2269 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2270 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2271 ZERO_POINT_HELPER(MatMulOp, A, true)
2272 ZERO_POINT_HELPER(MatMulOp, B, true)
2273 ZERO_POINT_HELPER(NegateOp, Input1, true)
2274 ZERO_POINT_HELPER(NegateOp, Output, true)
2275 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2276 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2277 #undef ZERO_POINT_HELPER
2278 
2279 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2280  MLIRContext *context, ::std::optional<Location> location,
2281  TransposeOp::Adaptor adaptor,
2282  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2283  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2284 
 2285  // If the input rank is unknown, the output rank (and shape) is also
 2286  // unknown.
2287  if (!inputShape.hasRank()) {
2288  inferredReturnShapes.push_back(ShapedTypeComponents());
2289  return success();
2290  }
2291 
2292  const auto inputRank = inputShape.getRank();
2293 
2294  // This would imply the number of permutations does not match the rank of
 2295  // the input, which is illegal.
2296  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2297  return failure();
2298  }
2299 
2300  SmallVector<int64_t> outputShape;
2301  // Rank-0 means no permutations matter.
2302  if (inputRank == 0) {
2303  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2304  return success();
2305  }
2306 
2307  // Check whether the input dimensions are all the same.
2308  bool allTheSame = true;
2309  for (int i = 1, s = inputRank; i < s; i++) {
2310  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2311  allTheSame = false;
2312  break;
2313  }
2314  }
2315 
2316  // If all of the input dimensions are the same we don't care about the
2317  // permutation.
2318  if (allTheSame) {
2319  outputShape.resize(inputRank, inputShape.getDimSize(0));
2320  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2321  return success();
2322  }
2323 
2324  outputShape.resize(inputRank, ShapedType::kDynamic);
2325 
2326  // Constant permutation values must be within the input rank.
2327  if (llvm::any_of(adaptor.getPerms(),
2328  [inputRank](const auto i) { return i >= inputRank; }))
2329  return failure();
2330 
2331  outputShape.reserve(inputRank);
2332  for (int i = 0, s = inputRank; i < s; i++) {
2333  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2334  }
2335 
2336  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2337  return success();
2338 }
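// Worked example (hypothetical values, for illustration): transposing
// tensor<1x2x3xf32> with perms = [2, 0, 1] sets output dim i to input dim
// perms[i], i.e. [3, 1, 2], so the inferred output shape is [3, 1, 2].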
2339 
2340 LogicalResult tosa::TransposeOp::verify() {
2341  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2342  /* outType = */ getOutput().getType())
2343  .failed()) {
2344  return failure();
2345  }
2346 
2347  const ShapeAdaptor inputShape(getInput1().getType());
2348  const ShapeAdaptor outputShape(getOutput().getType());
2349 
2350  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2351 
2352  if (inputShape.hasRank() &&
2353  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2354  return emitOpError() << "expected perms attribute to have size "
2355  << inputShape.getRank()
2356  << " (input rank) but got size "
2357  << constantPerms.size();
2358 
2359  if (inputShape.hasRank() && outputShape.hasRank() &&
2360  inputShape.getRank() != outputShape.getRank())
2361  return emitOpError()
2362  << "expected input tensor rank to equal result tensor rank";
2363 
2364  if (outputShape.hasRank() &&
2365  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2366  return emitOpError() << "expected perms attribute to have size "
2367  << outputShape.getRank()
2368  << " (output rank) but got size "
2369  << constantPerms.size();
2370 
2371  if (!llvm::all_of(constantPerms,
2372  [&constantPerms](int32_t s) {
2373  return s >= 0 &&
2374  static_cast<size_t>(s) < constantPerms.size();
2375  }) ||
2376  !isPermutationVector(llvm::to_vector(llvm::map_range(
2377  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2378  return emitOpError() << "expected valid permutation indices";
2379 
2380  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2381  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2382  inputShape.getNumElements() != outputShape.getNumElements())
2383  return emitOpError() << "expected input1 and output to have same numbers "
2384  "of elements, got "
2385  << inputShape.getNumElements() << " and "
2386  << outputShape.getNumElements();
2387 
2388  // Verify that the types of the input and output tensors are properly
2389  // permuted.
2390  if (inputShape.hasRank() && outputShape.hasRank()) {
2391  for (auto i = 0; i < outputShape.getRank(); i++) {
2392  if (inputShape.isDynamicDim(constantPerms[i]) ||
2393  outputShape.isDynamicDim(i))
2394  continue;
2395 
2396  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2397  return emitOpError()
2398  << "expected output tensor dim " << i << " to match "
2399  << "input dim " << constantPerms[i] << " with value of "
2400  << inputShape.getDimSize(constantPerms[i]);
2401  }
2402  }
2403 
2404  return success();
2405 }
2406 
2407 LogicalResult TransposeOp::reifyResultShapes(
2408  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2409 
2410  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2411 
2412  Value input = getInput1();
2413  auto inputType = cast<TensorType>(input.getType());
2414 
2415  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2416  for (auto dim : transposePerms) {
2417  int32_t dimInInput = transposePerms[dim];
2418  if (inputType.isDynamicDim(dimInInput))
2419  returnedDims[dim] =
2420  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2421  .getResult();
2422  else
2423  returnedDims[dim] =
2424  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2425  }
2426 
2427  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2428  return success();
2429 }
2430 
2431 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2432  MLIRContext *context, ::std::optional<Location> location,
2433  GatherOp::Adaptor adaptor,
2434  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2435  llvm::SmallVector<int64_t> outputShape;
2436  outputShape.resize(3, ShapedType::kDynamic);
2437 
2438  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2439  if (valuesShape.hasRank()) {
2440  outputShape[0] = valuesShape.getDimSize(0);
2441  outputShape[2] = valuesShape.getDimSize(2);
2442  }
2443 
2444  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2445  if (indicesShape.hasRank()) {
2446  if (outputShape[0] == ShapedType::kDynamic)
2447  outputShape[0] = indicesShape.getDimSize(0);
2448  if (outputShape[1] == ShapedType::kDynamic)
2449  outputShape[1] = indicesShape.getDimSize(1);
2450  }
2451 
2452  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2453  return success();
2454 }
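// Worked example (hypothetical shapes, for illustration): with
// values : tensor<2x10x4xf32> (N x K x C) and indices : tensor<2x3xi32>
// (N x W), the inference above produces the output shape N x W x C, i.e.
// [2, 3, 4].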
2455 
2456 LogicalResult tosa::GatherOp::verify() {
2457  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2458  /* outType = */ getOutput().getType())
2459  .failed()) {
2460  return failure();
2461  }
2462 
2463  const ShapeAdaptor valuesShape(getValues().getType());
2464  const ShapeAdaptor indicesShape(getIndices().getType());
2465  const ShapeAdaptor outputShape(getOutput().getType());
2466 
2467  int64_t N = ShapedType::kDynamic;
2468  int64_t W = ShapedType::kDynamic;
2469  int64_t C = ShapedType::kDynamic;
2470 
2471  if (valuesShape.hasRank()) {
2472  N = valuesShape.getDimSize(0);
2473  C = valuesShape.getDimSize(2);
2474  }
2475  if (indicesShape.hasRank()) {
2476  const int64_t indicesN = indicesShape.getDimSize(0);
2477  W = indicesShape.getDimSize(1);
2478  if (N == ShapedType::kDynamic)
2479  N = indicesN;
2480  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2481  return emitOpError() << "requires indices dimension 0 to have size " << N
2482  << ", got " << indicesN;
2483  }
2484  if (outputShape.hasRank()) {
2485  const int64_t outputN = outputShape.getDimSize(0);
2486  const int64_t outputW = outputShape.getDimSize(1);
2487  const int64_t outputC = outputShape.getDimSize(2);
2488  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2489  N != outputN)
2490  return emitOpError() << "requires output dimension 0 to have size " << N
2491  << ", got " << outputN;
2492 
2493  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2494  W != outputW)
2495  return emitOpError() << "requires output dimension 1 to have size " << W
2496  << ", got " << outputW;
2497  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2498  C != outputC)
2499  return emitOpError() << "requires output dimension 2 to have size " << C
2500  << ", got " << outputC;
2501  }
2502  return success();
2503 }
2504 
2505 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2506  MLIRContext *context, ::std::optional<Location> location,
2507  ResizeOp::Adaptor adaptor,
2508  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2509  llvm::SmallVector<int64_t, 4> outputShape;
2510  outputShape.resize(4, ShapedType::kDynamic);
2511 
2512  ShapeAdaptor inputShape(adaptor.getInput().getType());
2513  if (!inputShape.hasRank())
2514  return failure();
2515 
2516  outputShape[0] = inputShape.getDimSize(0);
2517  outputShape[3] = inputShape.getDimSize(3);
2518  int64_t inputHeight = inputShape.getDimSize(1);
2519  int64_t inputWidth = inputShape.getDimSize(2);
2520 
2521  if ((inputHeight == ShapedType::kDynamic) ||
2522  (inputWidth == ShapedType::kDynamic))
2523  return failure();
2524 
2525  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2526  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2527  scaleInt) ||
2528  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2529  offsetInt) ||
2530  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2531  borderInt)) {
2532  return failure();
2533  }
2534 
 2535  // Compute the output shape based on the constant scale, offset, and border values.
2536  outputShape[1] =
2537  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2538  scaleInt[1]) +
2539  1;
2540 
2541  outputShape[2] =
2542  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2543  scaleInt[3]) +
2544  1;
2545 
2546  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2547  return success();
2548 }
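// Worked example (hypothetical values, for illustration): with input height
// 4, scale = [2, 1, 2, 1], offset = [0, 0] and border = [0, 0], the formula
// above gives output height ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7; the output
// width is computed the same way from scale[2], scale[3], offset[1], and
// border[1].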
2549 
2550 LogicalResult tosa::ResizeOp::verify() {
2551  const Value input = getInput();
2552  const Value output = getOutput();
2553  const RankedTensorType inputType =
2554  llvm::dyn_cast<RankedTensorType>(input.getType());
2555  const RankedTensorType outputType =
2556  llvm::dyn_cast<RankedTensorType>(output.getType());
2557 
2558  SmallVector<int64_t> scaleValues;
2559  SmallVector<int64_t> offsetValues;
2560  SmallVector<int64_t> borderValues;
2561  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2562  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2563  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2564  // Skip following checks if shape is not constant
2565  return success();
2566  }
2567 
2568  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2569  return emitOpError("expect all scale values to be > 0, got ")
2570  << scaleValues;
2571 
2572  const int64_t scaleYN = scaleValues[0];
2573  const int64_t scaleYD = scaleValues[1];
2574  const int64_t scaleXN = scaleValues[2];
2575  const int64_t scaleXD = scaleValues[3];
2576 
2577  const int64_t offsetY = offsetValues[0];
2578  const int64_t offsetX = offsetValues[1];
2579 
2580  const int64_t borderY = borderValues[0];
2581  const int64_t borderX = borderValues[1];
2582 
2583  if (!inputType)
2584  return success();
2585  if (!outputType)
2586  return success();
2587 
2588  const int64_t oh = outputType.getDimSize(1);
2589  const int64_t ow = outputType.getDimSize(2);
2590  const int64_t ih = inputType.getDimSize(1);
2591  const int64_t iw = inputType.getDimSize(2);
2592 
2593  // Don't check with input height that could be broadcast (ih != 1)
2594  // since Linalg, a consumer of TOSA, expects broadcasting support
2595  // in resize to be available. Taking the cautious approach for now,
2596  // we can consider removing support for broadcasting later.
2597  if (ih != ShapedType::kDynamic && ih != 1) {
2598  const std::optional<int64_t> calculatedOutHeightMinusOne =
2599  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2600  if (!calculatedOutHeightMinusOne.has_value())
2601  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2602  "border_y ")
2603  << "to be wholly divisible by scale_y_d, got ((" << ih
2604  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2605  << ") / " << scaleYD;
2606  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2607  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2608  return emitOpError("calculated output height did not match expected: ")
2609  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2610  }
2611 
2612  // Don't check with input width that could be broadcast (iw != 1)
2613  // since Linalg, a consumer of TOSA, expects broadcasting support
2614  // in resize to be available. Taking the cautious approach for now,
2615  // we can consider removing support for broadcasting later.
2616  if (iw != ShapedType::kDynamic && iw != 1) {
2617  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2618  const std::optional<int64_t> calculatedOutWidthMinusOne =
2619  idivCheck(scaledInWidth, scaleXD);
2620  if (!calculatedOutWidthMinusOne.has_value())
2621  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2622  "border_x ")
2623  << "to be wholly divisible by scale_x_d, got ((" << iw
2624  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2625  << ") / " << scaleXD;
2626  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2627  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2628  return emitOpError("calculated output width did not match expected: ")
2629  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2630  }
2631 
2632  return success();
2633 }
2634 
2635 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2636  MLIRContext *context, ::std::optional<Location> location,
2637  ScatterOp::Adaptor adaptor,
2638  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2639  llvm::SmallVector<int64_t> outputShape;
2640  outputShape.resize(3, ShapedType::kDynamic);
2641 
2642  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2643  if (valuesInShape.hasRank()) {
2644  outputShape[0] = valuesInShape.getDimSize(0);
2645  outputShape[1] = valuesInShape.getDimSize(1);
2646  outputShape[2] = valuesInShape.getDimSize(2);
2647  }
2648 
2649  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2650  if (indicesShape.hasRank()) {
2651  if (outputShape[0] == ShapedType::kDynamic)
2652  outputShape[0] = indicesShape.getDimSize(0);
2653  }
2654 
2655  ShapeAdaptor inputShape(adaptor.getInput().getType());
2656  if (inputShape.hasRank()) {
2657  if (outputShape[0] == ShapedType::kDynamic)
2658  outputShape[0] = inputShape.getDimSize(0);
2659  if (outputShape[2] == ShapedType::kDynamic)
2660  outputShape[2] = inputShape.getDimSize(2);
2661  }
2662 
2663  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2664  return success();
2665 }
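// Worked example (hypothetical shapes, for illustration): with
// values_in : tensor<2x10x4xf32> (N x K x C), indices : tensor<2x3xi32>, and
// input : tensor<2x3x4xf32>, the inferred values_out shape follows values_in,
// i.e. [2, 10, 4]; K = 10 >= W = 3 also satisfies the verifier below.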
2666 
2667 LogicalResult tosa::ScatterOp::verify() {
2668  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2669  /* outType = */ getValuesOut().getType())
2670  .failed() ||
2671  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2672  /* outType = */ getValuesOut().getType())
2673  .failed()) {
2674  return failure();
2675  }
2676 
2677  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2678  const ShapeAdaptor indicesShape(getIndices().getType());
2679  const ShapeAdaptor inputShape(getInput().getType());
2680  const ShapeAdaptor outputShape(getValuesOut().getType());
2681 
2682  int64_t N = ShapedType::kDynamic;
2683  int64_t K = ShapedType::kDynamic;
2684  int64_t W = ShapedType::kDynamic;
2685  int64_t C = ShapedType::kDynamic;
2686  if (valuesInShape.hasRank()) {
2687  N = valuesInShape.getDimSize(0);
2688  K = valuesInShape.getDimSize(1);
2689  C = valuesInShape.getDimSize(2);
2690  }
2691  if (indicesShape.hasRank()) {
2692  const int64_t indicesN = indicesShape.getDimSize(0);
2693  W = indicesShape.getDimSize(1);
2694  if (N == ShapedType::kDynamic)
2695  N = indicesN;
2696  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2697  return emitOpError() << "requires indices dimension 0 to have size " << N
2698  << ", got " << indicesN;
2699  }
2700  if (inputShape.hasRank()) {
2701  const int64_t inputN = inputShape.getDimSize(0);
2702  const int64_t inputW = inputShape.getDimSize(1);
2703  const int64_t inputC = inputShape.getDimSize(2);
2704  if (N == ShapedType::kDynamic)
2705  N = inputN;
2706  else if (inputN != ShapedType::kDynamic && N != inputN)
2707  return emitOpError() << "requires input dimension 0 to have size " << N
2708  << ", got " << inputN;
2709  if (W == ShapedType::kDynamic)
2710  W = inputW;
2711  else if (inputW != ShapedType::kDynamic && W != inputW)
2712  return emitOpError() << "requires input dimension 1 to have size " << W
2713  << ", got " << inputW;
2714 
2715  if (C == ShapedType::kDynamic)
2716  C = inputC;
2717  else if (inputC != ShapedType::kDynamic && C != inputC)
2718  return emitOpError() << "requires input dimension 2 to have size " << C
2719  << ", got " << inputC;
2720  }
2721  if (outputShape.hasRank()) {
2722  const int64_t outputN = outputShape.getDimSize(0);
2723  const int64_t outputK = outputShape.getDimSize(1);
2724  const int64_t outputC = outputShape.getDimSize(2);
2725  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2726  N != outputN)
2727  return emitOpError() << "requires values_out dimension 0 to have size "
2728  << N << ", got " << outputN;
2729  if (K == ShapedType::kDynamic)
2730  K = outputK;
2731  else if (outputK != ShapedType::kDynamic && K != outputK)
2732  return emitOpError() << "requires values_out dimension 1 to have size "
2733  << K << ", got " << outputK;
2734  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2735  C != outputC)
2736  return emitOpError() << "requires values_out dimension 2 to have size "
2737  << C << ", got " << outputC;
2738  }
2739  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2740  return emitOpError() << "requires dimensions K >= W, got K=" << K
2741  << " and W=" << W;
2742 
2743  return success();
2744 }
2745 
2746 static LogicalResult ReduceInferReturnTypes(
2747  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2748  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2749  int64_t axisVal = axis.getValue().getSExtValue();
2750  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2751  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2752  return success();
2753  }
2754 
2755  SmallVector<int64_t> outputShape;
2756  operandShape.getDims(outputShape);
2757  outputShape[axisVal] = 1;
2758  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2759  return success();
2760 }
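// Worked example (hypothetical values, for illustration): reducing
// tensor<2x3x4xf32> along axis = 1 keeps the other dims and sets the reduced
// dim to 1, giving tensor<2x1x4xf32>.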
2761 
2762 #define COMPATIBLE_RETURN_TYPES(OP) \
2763  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2764  if (l.size() != r.size() || l.size() != 1) \
2765  return false; \
2766  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2767  return false; \
2768  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2769  }
2770 
2771 #define REDUCE_SHAPE_INFER(OP) \
2772  LogicalResult OP::inferReturnTypeComponents( \
2773  MLIRContext *context, ::std::optional<Location> location, \
2774  OP::Adaptor adaptor, \
2775  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2776  Type inputType = \
2777  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2778  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2779  const Properties &prop = adaptor.getProperties(); \
2780  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2781  inferredReturnShapes); \
2782  } \
2783  COMPATIBLE_RETURN_TYPES(OP)
2784 
2785 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2786 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2787 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2788 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2789 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2790 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2791 #undef REDUCE_SHAPE_INFER
2792 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2793 #undef COMPATIBLE_RETURN_TYPES
2794 
2795 template <typename T>
2796 static LogicalResult verifyReduceOp(T op) {
2797  // All TOSA reduce Ops have input, output and axis.
2798  TensorType inputType = op.getInput().getType();
2799  TensorType outputType = op.getOutput().getType();
2800  int32_t reduceAxis = op.getAxis();
2801 
2802  if (reduceAxis < 0) {
2803  op.emitOpError("reduce axis must not be negative");
2804  return failure();
2805  }
2806  if (inputType.hasRank()) {
2807  int64_t inputRank = inputType.getRank();
2808  // We allow for a special case where the input/output shape has rank 0 and
2809  // axis is also 0.
2810  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2811  op.emitOpError("expect input tensor rank (")
2812  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2813  << ")";
2814  return failure();
2815  }
2816  }
2817  if (outputType.hasRank()) {
2818  int64_t outputRank = outputType.getRank();
2819  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2820  op.emitOpError(
2821  "expect output tensor rank to be equal to input tensor rank");
2822  return failure();
2823  }
2824  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2825  op.emitOpError("expect output tensor rank (")
2826  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2827  << ")";
2828  return failure();
2829  }
2830  // We can only verify the reduced dimension size to be 1 if this is not
2831  // the special case of output rank == 0.
2832  if (outputRank != 0) {
2833  auto outputShape = outputType.getShape();
2834  if (!outputType.isDynamicDim(reduceAxis) &&
2835  outputShape[reduceAxis] != 1) {
2836  op.emitOpError("expect reduced dimension size to be 1, got ")
2837  << outputShape[reduceAxis];
2838  return failure();
2839  }
2840  }
2841  }
2842  return success();
2843 }
2844 
2845 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2846 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2847 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2848 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2849 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2850 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2851 
2852 static LogicalResult NAryInferReturnTypes(
2853  const ValueShapeRange &operands,
2854  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2855  llvm::SmallVector<int64_t> outShape;
2856  if (resolveBroadcastShape(operands, outShape).failed()) {
2857  inferredReturnShapes.push_back(ShapedTypeComponents());
2858  } else {
2859  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2860  }
2861  return success();
2862 }
2863 
2864 #define NARY_SHAPE_INFER(OP) \
2865  LogicalResult OP::inferReturnTypeComponents( \
2866  MLIRContext *context, ::std::optional<Location> location, \
2867  ValueShapeRange operands, DictionaryAttr attributes, \
2868  OpaqueProperties properties, RegionRange regions, \
2869  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2870  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2871  }
2872 
2873 NARY_SHAPE_INFER(tosa::AbsOp)
2874 NARY_SHAPE_INFER(tosa::AddOp)
2875 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2876 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2877 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2878 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2879 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2880 NARY_SHAPE_INFER(tosa::CastOp)
2881 NARY_SHAPE_INFER(tosa::CeilOp)
2882 NARY_SHAPE_INFER(tosa::ClampOp)
2883 NARY_SHAPE_INFER(tosa::ClzOp)
2884 NARY_SHAPE_INFER(tosa::CosOp)
2885 NARY_SHAPE_INFER(tosa::ExpOp)
2886 NARY_SHAPE_INFER(tosa::FloorOp)
2887 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2888 NARY_SHAPE_INFER(tosa::GreaterOp)
2889 NARY_SHAPE_INFER(tosa::IdentityOp)
2890 NARY_SHAPE_INFER(tosa::IntDivOp)
2891 NARY_SHAPE_INFER(tosa::LogOp)
2892 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2893 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2894 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2895 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2896 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2897 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2898 NARY_SHAPE_INFER(tosa::MaximumOp)
2899 NARY_SHAPE_INFER(tosa::MinimumOp)
2900 NARY_SHAPE_INFER(tosa::PowOp)
2901 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2902 NARY_SHAPE_INFER(tosa::ReverseOp)
2903 NARY_SHAPE_INFER(tosa::RsqrtOp)
2904 NARY_SHAPE_INFER(tosa::SinOp)
2905 NARY_SHAPE_INFER(tosa::SelectOp)
2906 NARY_SHAPE_INFER(tosa::SubOp)
2907 NARY_SHAPE_INFER(tosa::TanhOp)
2908 NARY_SHAPE_INFER(tosa::ErfOp)
2909 NARY_SHAPE_INFER(tosa::SigmoidOp)
 2910 #undef NARY_SHAPE_INFER
2911 
2912 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2913  MLIRContext *context, ::std::optional<Location> location,
2914  NegateOp::Adaptor adaptor,
2915  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2916  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2917  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2918  return success();
2919 }
2920 
2921 LogicalResult tosa::NegateOp::verify() {
2922  // Verify same element type
2923  const Type input1Type = getInput1().getType();
2924  const Type outputType = getOutput().getType();
2925  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2926  return failure();
2927 
2928  // Verify same shape
2929  const SmallVector<Type, 2> types = {input1Type, outputType};
2930  if (failed(verifyCompatibleShapes(types)))
2931  return emitOpError() << "requires the same shape for input1 and output";
2932 
2933  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2934  const Type input1ZpEType =
2935  getStorageElementTypeOrSelf(getInput1Zp().getType());
2936  if (input1EType != input1ZpEType) {
2937  return emitOpError("expect both input1 and its zero point are the same "
2938  "element type, got ")
2939  << input1EType << " and " << input1ZpEType;
2940  }
2941  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2942  const Type outputZpEType =
2943  getStorageElementTypeOrSelf(getOutputZp().getType());
2944  if (outputEType != outputZpEType) {
2945  return emitOpError("expect both output and its zero point are the same "
2946  "element type, got ")
2947  << outputEType << " and " << outputZpEType;
2948  }
2949 
2950  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2951  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2952  return failure();
2953 
2954  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2955  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2956  return failure();
2957 
2958  return success();
2959 }
2960 
2961 static LogicalResult poolingInferReturnTypes(
2962  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2963  ArrayRef<int64_t> pad,
2964  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2965  llvm::SmallVector<int64_t> outputShape;
2966  outputShape.resize(4, ShapedType::kDynamic);
2967 
2968  // If we cannot get a shape for the input, the output keeps rank 4 with all
2969  // dimensions dynamic.
2969  if (!inputShape) {
2970  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2971  return success();
2972  }
2973 
2974  // Batch and channel dimensions are passed through unchanged by pooling.
2975  outputShape[0] = inputShape.getDimSize(0);
2976  outputShape[3] = inputShape.getDimSize(3);
2977 
2978  int64_t height = inputShape.getDimSize(1);
2979  int64_t width = inputShape.getDimSize(2);
2980 
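// Each spatial output size is (in + pad_before + pad_after - kernel) / stride + 1
// with integer division. For example, with hypothetical values in = 112,
// pads = 1 + 1, kernel = 3 and stride = 2: (112 + 2 - 3) / 2 + 1 = 56.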
2981  if (!ShapedType::isDynamic(height)) {
2982  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2983  outputShape[1] = padded / stride[0] + 1;
2984  }
2985 
2986  if (!ShapedType::isDynamic(width)) {
2987  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2988  outputShape[2] = padded / stride[1] + 1;
2989  }
2990 
2991  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2992  return success();
2993 }
2994 
2995 LogicalResult Conv2DOp::inferReturnTypeComponents(
2996  MLIRContext *context, ::std::optional<Location> location,
2997  Conv2DOp::Adaptor adaptor,
2998  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2999  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3000 
3001  int64_t inputWidth = ShapedType::kDynamic;
3002  int64_t inputHeight = ShapedType::kDynamic;
3003  int64_t weightWidth = ShapedType::kDynamic;
3004  int64_t weightHeight = ShapedType::kDynamic;
3005 
3006  // Input shape describes input width/height and batch.
3007 
3008  ShapeAdaptor inputShape(adaptor.getInput().getType());
3009  if (inputShape.hasRank()) {
3010  outputShape[0] = inputShape.getDimSize(0);
3011  inputHeight = inputShape.getDimSize(1);
3012  inputWidth = inputShape.getDimSize(2);
3013  }
3014 
3015  // The weight shape describes the filter width/height and the output channels.
3016  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3017  if (weightShape.hasRank()) {
3018  outputShape[3] = weightShape.getDimSize(0);
3019  weightHeight = weightShape.getDimSize(1);
3020  weightWidth = weightShape.getDimSize(2);
3021  }
3022 
3023  // Bias shape can describe the output channels.
3024  ShapeAdaptor biasShape(adaptor.getBias().getType());
3025  if (biasShape.hasRank()) {
3026  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3027  ? biasShape.getDimSize(0)
3028  : outputShape[3];
3029  }
3030 
3031  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3032  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3033  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3034 
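// The dilated kernel spans (K - 1) * dilation + 1 elements, and the strided
// output size below equals ceil((in + pads - dilatedK + 1) / stride). For
// example, with hypothetical values in = 56, pads = 1 + 1, K = 3, dilation = 2
// and stride = 2: dilatedK = 5, unstrided = 54, output = (54 - 1) / 2 + 1 = 27.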
3035  if (!ShapedType::isDynamic(inputHeight) &&
3036  !ShapedType::isDynamic(weightHeight)) {
3037  int64_t inputSize = inputHeight + padding[0] + padding[1];
3038  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3039  int64_t unstridedResult = inputSize - filterSize + 1;
3040  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3041  }
3042 
3043  if (!ShapedType::isDynamic(inputWidth) &&
3044  !ShapedType::isDynamic(weightWidth)) {
3045  int64_t inputSize = inputWidth + padding[2] + padding[3];
3046  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3047  int64_t unstridedResult = inputSize - filterSize + 1;
3048  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3049  }
3050 
3051  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3052  return success();
3053 }
3054 
3055 LogicalResult Conv2DOp::verify() {
3056  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3057  verifyConvOpErrorIf(*this).failed())
3058  return failure();
3059  return success();
3060 }
3061 
3062 LogicalResult Conv3DOp::inferReturnTypeComponents(
3063  MLIRContext *context, ::std::optional<Location> location,
3064  Conv3DOp::Adaptor adaptor,
3065  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3066  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3067 
3068  int64_t inputWidth = ShapedType::kDynamic;
3069  int64_t inputHeight = ShapedType::kDynamic;
3070  int64_t inputDepth = ShapedType::kDynamic;
3071 
3072  int64_t weightWidth = ShapedType::kDynamic;
3073  int64_t weightHeight = ShapedType::kDynamic;
3074  int64_t weightDepth = ShapedType::kDynamic;
3075 
3076  // Input shape describes input width/height and batch.
3077  ShapeAdaptor inputShape(adaptor.getInput().getType());
3078  if (inputShape.hasRank()) {
3079  outputShape[0] = inputShape.getDimSize(0);
3080  inputDepth = inputShape.getDimSize(1);
3081  inputHeight = inputShape.getDimSize(2);
3082  inputWidth = inputShape.getDimSize(3);
3083  }
3084 
3085  // The weight shape describes the filter width/height and the output channels.
3086  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3087  if (weightShape.hasRank()) {
3088  outputShape[4] = weightShape.getDimSize(0);
3089  weightDepth = weightShape.getDimSize(1);
3090  weightHeight = weightShape.getDimSize(2);
3091  weightWidth = weightShape.getDimSize(3);
3092  }
3093 
3094  // Bias shape can describe the output channels.
3095  ShapeAdaptor biasShape(adaptor.getBias().getType());
3096  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3097  outputShape[4] = biasShape.getDimSize(0);
3098  }
3099 
3100  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3101  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3102  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3103 
3104  if (!ShapedType::isDynamic(inputDepth) &&
3105  !ShapedType::isDynamic(weightDepth)) {
3106  int32_t inputSize = inputDepth + pad[0] + pad[1];
3107  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3108  int32_t unstridedResult = inputSize - filterSize + 1;
3109  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3110  }
3111 
3112  if (!ShapedType::isDynamic(inputHeight) &&
3113  !ShapedType::isDynamic(weightHeight)) {
3114  int32_t inputSize = inputHeight + pad[2] + pad[3];
3115  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3116  int32_t unstridedResult = inputSize - filterSize + 1;
3117  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3118  }
3119 
3120  if (!ShapedType::isDynamic(inputWidth) &&
3121  !ShapedType::isDynamic(weightWidth)) {
3122  int32_t inputSize = inputWidth + pad[4] + pad[5];
3123  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3124  int32_t unstridedResult = inputSize - filterSize + 1;
3125  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3126  }
3127 
3128  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3129  return success();
3130 }
3131 
3132 LogicalResult Conv3DOp::verify() {
3133  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3134  verifyConvOpErrorIf(*this).failed())
3135  return failure();
3136  return success();
3137 }
3138 
3139 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3140  MLIRContext *context, ::std::optional<Location> location,
3141  AvgPool2dOp::Adaptor adaptor,
3142  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3143  ShapeAdaptor inputShape(adaptor.getInput().getType());
3144  const Properties &prop = adaptor.getProperties();
3145  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3146  inferredReturnShapes);
3147 }
3148 
3149 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3150  MLIRContext *context, ::std::optional<Location> location,
3151  MaxPool2dOp::Adaptor adaptor,
3152  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3153  ShapeAdaptor inputShape(adaptor.getInput().getType());
3154  const Properties &prop = adaptor.getProperties();
3155  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3156  inferredReturnShapes);
3157 }
3158 
3159 LogicalResult MaxPool2dOp::verify() {
3160  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3161  /* outType = */ getOutput().getType())))
3162  return failure();
3163 
3164  if (failed(verifyPoolingOp(*this)))
3165  return failure();
3166 
3167  return success();
3168 }
3169 
3170 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3171  MLIRContext *context, ::std::optional<Location> location,
3172  DepthwiseConv2DOp::Adaptor adaptor,
3173  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3174  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3175 
3176  int64_t inputWidth = ShapedType::kDynamic;
3177  int64_t inputHeight = ShapedType::kDynamic;
3178  int64_t inputChannels = ShapedType::kDynamic;
3179 
3180  int64_t weightWidth = ShapedType::kDynamic;
3181  int64_t weightHeight = ShapedType::kDynamic;
3182  int64_t depthChannels = ShapedType::kDynamic;
3183 
3184  // Input shape describes input width/height and batch.
3185  ShapeAdaptor inputShape(adaptor.getInput().getType());
3186  if (inputShape.hasRank()) {
3187  outputShape[0] = inputShape.getDimSize(0);
3188  inputHeight = inputShape.getDimSize(1);
3189  inputWidth = inputShape.getDimSize(2);
3190  inputChannels = inputShape.getDimSize(3);
3191  }
3192 
3193  // The weight shape describes the filter width/height, the input channels,
3194  // and the depth multiplier.
3194  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3195  if (weightShape.hasRank()) {
3196  weightHeight = weightShape.getDimSize(0);
3197  weightWidth = weightShape.getDimSize(1);
3198  inputChannels = ShapedType::isDynamic(inputChannels)
3199  ? weightShape.getDimSize(2)
3200  : inputChannels;
3201  depthChannels = weightShape.getDimSize(3);
3202  }
3203 
3204  // If both inputChannels and depthChannels are available we can determine
3205  // the output channels.
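// For example, with hypothetical values inputChannels = 8 and a depth
// multiplier of 3, the result has 24 output channels.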
3206  if (!ShapedType::isDynamic(inputChannels) &&
3207  !ShapedType::isDynamic(depthChannels)) {
3208  outputShape[3] = inputChannels * depthChannels;
3209  }
3210 
3211  // Bias shape can describe the output channels.
3212  ShapeAdaptor biasShape(adaptor.getBias().getType());
3213  if (biasShape.hasRank()) {
3214  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3215  ? biasShape.getDimSize(0)
3216  : outputShape[3];
3217  }
3218 
3219  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3220  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3221  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3222 
3223  if (!ShapedType::isDynamic(inputHeight) &&
3224  !ShapedType::isDynamic(weightHeight)) {
3225  int64_t inputSize = inputHeight + padding[0] + padding[1];
3226  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3227  int64_t unstridedResult = inputSize - filterSize + 1;
3228  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3229  }
3230 
3231  if (!ShapedType::isDynamic(inputWidth) &&
3232  !ShapedType::isDynamic(weightWidth)) {
3233  int64_t inputSize = inputWidth + padding[2] + padding[3];
3234  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3235  int64_t unstridedResult = inputSize - filterSize + 1;
3236  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3237  }
3238 
3239  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3240  return success();
3241 }
3242 
3243 LogicalResult DepthwiseConv2DOp::verify() {
3244  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3245  verifyConvOpErrorIf(*this).failed())
3246  return failure();
3247  return success();
3248 }
3249 
3250 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3251  MLIRContext *context, ::std::optional<Location> location,
3252  TransposeConv2DOp::Adaptor adaptor,
3253  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3254  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3255 
3256  int64_t inputWidth = ShapedType::kDynamic;
3257  int64_t inputHeight = ShapedType::kDynamic;
3258  int64_t weightWidth = ShapedType::kDynamic;
3259  int64_t weightHeight = ShapedType::kDynamic;
3260 
3261  // Input shape describes input width/height and batch.
3262  ShapeAdaptor inputShape(adaptor.getInput().getType());
3263  if (inputShape.hasRank()) {
3264  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3265  ? inputShape.getDimSize(0)
3266  : outputShape[0];
3267  inputHeight = inputShape.getDimSize(1);
3268  inputWidth = inputShape.getDimSize(2);
3269  }
3270 
3271  // The weight shape describes the filter width/height and the output channels.
3272  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3273  if (weightShape.hasRank()) {
3274  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3275  ? weightShape.getDimSize(0)
3276  : outputShape[3];
3277  weightHeight = weightShape.getDimSize(1);
3278  weightWidth = weightShape.getDimSize(2);
3279  }
3280 
3281  // Bias shape can describe the output channels.
3282  ShapeAdaptor biasShape(adaptor.getBias().getType());
3283  if (biasShape.hasRank()) {
3284  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3285  ? biasShape.getDimSize(0)
3286  : outputShape[3];
3287  }
3288 
3289  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3290  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3291 
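// Transposed convolution grows the spatial dimensions:
// out = (in - 1) * stride + out_pad_before + out_pad_after + K. For example,
// with hypothetical values in = 14, stride = 2, out_pads = 0 + 0 and K = 3,
// the output size is 29.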
3292  if (!ShapedType::isDynamic(inputHeight) &&
3293  !ShapedType::isDynamic(weightHeight)) {
3294  int64_t calculateSize =
3295  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3296  outputShape[1] =
3297  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3298  }
3299 
3300  if (!ShapedType::isDynamic(inputWidth) &&
3301  !ShapedType::isDynamic(weightWidth)) {
3302  int64_t calculateSize =
3303  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3304  outputShape[2] =
3305  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3306  }
3307 
3308  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3309  return success();
3310 }
3311 
3312 LogicalResult TransposeConv2DOp::verify() {
3313  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3314  return failure();
3315 
3316  const llvm::ArrayRef<int64_t> strides = getStride();
3317  const int64_t strideY = strides[0];
3318  const int64_t strideX = strides[1];
3319 
3320  if (strideY < 1 || strideX < 1)
3321  return emitOpError("expect all stride values to be >= 1, got [")
3322  << strides << "]";
3323 
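// Negative output padding is permitted, but each pad value must stay strictly
// greater than -K (checked below), so no single pad can cancel the whole
// kernel extent in the output size formula verified further down.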
3324  const auto checkPadAgainstKernelDim =
3325  [this](int64_t pad_value, int64_t kernel_dim_size,
3326  llvm::StringRef pad_name,
3327  llvm::StringRef kernel_dim_name) -> LogicalResult {
3328  if (pad_value <= -kernel_dim_size)
3329  return emitOpError("expected ")
3330  << pad_name << " > -" << kernel_dim_name
3331  << ", but got: " << pad_name << "=" << pad_value << " and "
3332  << kernel_dim_name << "=" << kernel_dim_size;
3333  return success();
3334  };
3335 
3336  const llvm::ArrayRef<int64_t> padding = getOutPad();
3337  const int64_t outPadTop = padding[0];
3338  const int64_t outPadBottom = padding[1];
3339  const int64_t outPadLeft = padding[2];
3340  const int64_t outPadRight = padding[3];
3341 
3342  const auto weightType =
3343  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3344 
3345  if (weightType) {
3346  const int64_t kernelHeight = weightType.getDimSize(1);
3347  if (!ShapedType::isDynamic(kernelHeight)) {
3348  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3349  "out_pad_top", "KH")))
3350  return failure();
3351 
3352  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3353  "out_pad_bottom", "KH")))
3354  return failure();
3355  }
3356 
3357  const int64_t kernelWidth = weightType.getDimSize(2);
3358  if (!ShapedType::isDynamic(kernelWidth)) {
3359  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3360  "out_pad_left", "KW")))
3361  return failure();
3362 
3363  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3364  "out_pad_right", "KW")))
3365  return failure();
3366  }
3367  }
3368 
3369  // Rest of the checks depend on the output type being a RankedTensorType
3370  const auto outputType =
3371  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3372  if (!outputType)
3373  return success();
3374 
3375  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3376  if (inputType && weightType) {
3377  const int64_t inputHeight = inputType.getDimSize(1);
3378  const int64_t kernelHeight = weightType.getDimSize(1);
3379  const int64_t outputHeight = outputType.getDimSize(1);
3380 
3381  if (!ShapedType::isDynamic(inputHeight) &&
3382  !ShapedType::isDynamic(outputHeight)) {
3383  if (outputHeight !=
3384  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3385  return emitOpError(
3386  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3387  "+ out_pad_top + out_pad_bottom + KH, but got ")
3388  << outputHeight << " != (" << inputHeight << " - 1) * "
3389  << strideY << " + " << outPadTop << " + " << outPadBottom
3390  << " + " << kernelHeight;
3391  }
3392 
3393  const int64_t inputWidth = inputType.getDimSize(2);
3394  const int64_t kernelWidth = weightType.getDimSize(2);
3395  const int64_t outputWidth = outputType.getDimSize(2);
3396 
3397  if (!ShapedType::isDynamic(inputWidth) &&
3398  !ShapedType::isDynamic(outputWidth)) {
3399  if (outputWidth !=
3400  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3401  return emitOpError(
3402  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3403  "+ out_pad_left + out_pad_right + KW, but got ")
3404  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3405  << " + " << outPadLeft << " + " << outPadRight << " + "
3406  << kernelWidth;
3407  }
3408  }
3409 
3410  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3411 
3412  if (!biasType)
3413  return success();
3414 
3415  const int64_t biasChannels = biasType.getDimSize(0);
3416 
3417  // Skip further checks if bias is dynamic
3418  if (biasChannels == ShapedType::kDynamic)
3419  return success();
3420 
3421  const int64_t outputChannels = outputType.getDimSize(3);
3422  if (biasChannels != outputChannels && biasChannels != 1)
3423  return emitOpError(
3424  "bias channels expected to be equal to output channels (")
3425  << outputChannels << ") or 1, got " << biasChannels;
3426 
3427  return success();
3428 }
3429 
3430 LogicalResult RescaleOp::verify() {
3431  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3432  if (!inputType) {
3433  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3434  return failure();
3435  }
3436 
3437  auto inputElementType =
3438  getStorageElementTypeOrSelf(inputType.getElementType());
3439  if (!mlir::isa<IntegerType>(inputElementType)) {
3440  emitOpError("expect input to have integer element type, got ")
3441  << inputElementType;
3442  return failure();
3443  }
3444 
3445  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3446  if (!outputType) {
3447  emitOpError("expect shaped tensor for output, got ")
3448  << getOutput().getType();
3449  return failure();
3450  }
3451 
3452  auto outputElementType =
3453  getStorageElementTypeOrSelf(outputType.getElementType());
3454  if (!mlir::isa<IntegerType>(outputElementType)) {
3455  emitOpError("expect output to have integer element type, got ")
3456  << outputElementType;
3457  return failure();
3458  }
3459 
3460  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3461  .failed())
3462  return failure();
3463 
3464  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3465  .failed())
3466  return failure();
3467 
3468  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3469  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3470  return failure();
3471 
3472  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3473  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3474  return failure();
3475 
3476  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3477  if (!multiplierType) {
3478  emitOpError("expect shaped tensor for multiplier, got ")
3479  << getMultiplier().getType();
3480  return failure();
3481  }
3482 
3483  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3484  if (!shiftType) {
3485  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3486  return failure();
3487  }
3488 
3489  // multiplier element type must be i32 for scale32 = true
3490  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3491  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3492  << multiplierType.getElementType();
3493  return failure();
3494  }
3495 
3496  // multiplier element type must be i16 for scale32 = false
3497  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3498  emitOpError(
3499  "expect i16 element type for multiplier for scale32=false, got ")
3500  << multiplierType.getElementType();
3501  return failure();
3502  }
3503 
3504  if (!inputType.hasRank())
3505  return success();
3506 
3507  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3508  // when per_channel = false and otherwise is the size of the input shape's
3509  // last axis.
3510  int64_t numChannels = 1;
3511  if (getPerChannel()) {
3512  if (inputType.getRank() < 1) {
3513  emitOpError("requires input to be at least rank 1 when per_channel is "
3514  "true, but got rank ")
3515  << inputType.getRank();
3516  return failure();
3517  }
3518  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3519  }
3520 
3521  if (!multiplierType.hasRank())
3522  return success();
3523 
3524  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3525  // multiplier input has rank 1 by dialect definition
3526  if (multiplierShape[0] != ShapedType::kDynamic &&
3527  multiplierShape[0] != numChannels) {
3528  emitOpError("expect shape of { ")
3529  << numChannels << " } for multiplier input, got { "
3530  << multiplierShape[0] << " }";
3531  return failure();
3532  }
3533 
3534  if (!shiftType.hasRank())
3535  return success();
3536 
3537  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3538  // shift input has rank 1 by dialect definition
3539  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3540  emitOpError("expect shape of { ")
3541  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3542  return failure();
3543  }
3544 
3545  return success();
3546 }
3547 
3548 LogicalResult RescaleOp::inferReturnTypeComponents(
3549  MLIRContext *context, ::std::optional<Location> location,
3550  RescaleOp::Adaptor adaptor,
3551  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3552  ShapeAdaptor inputShape(adaptor.getInput().getType());
3553  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3554  return success();
3555 }
3556 
3557 LogicalResult IfOp::inferReturnTypeComponents(
3558  MLIRContext *context, ::std::optional<Location> location,
3559  IfOp::Adaptor adaptor,
3560  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3561  llvm::SmallVector<tosa::YieldOp> yieldOps;
3562  for (Region *region : adaptor.getRegions()) {
3563  for (auto &block : *region)
3564  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3565  yieldOps.push_back(returnOp);
3566  }
3567 
3568  if (yieldOps.empty())
3569  return failure();
3570 
3571  // Get the initial type information for the yield op.
3572  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3573  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3574  for (auto operand : yieldOps.front().getOperands()) {
3575  resultKnowledge.push_back(
3576  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3577  }
3578 
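// Combine, via ValueKnowledge::meet, the knowledge contributed by every
// yield; when a yield operand is incompatible with the current knowledge the
// existing entry is kept.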
3579  for (auto yieldOp : yieldOps) {
3580  if (resultKnowledge.size() != yieldOp.getNumOperands())
3581  return failure();
3582 
3583  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3584  int32_t index = it.index();
3585  auto meet = ValueKnowledge::meet(
3586  resultKnowledge[index],
3587  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3588  if (!meet)
3589  continue;
3590  resultKnowledge[index] = meet;
3591  }
3592  }
3593 
3594  for (const ValueKnowledge &result : resultKnowledge) {
3595  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3596  }
3597 
3598  return success();
3599 }
3600 
3601 LogicalResult WhileOp::inferReturnTypeComponents(
3602  MLIRContext *context, ::std::optional<Location> location,
3603  WhileOp::Adaptor adaptor,
3604  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3605  llvm::SmallVector<tosa::YieldOp> yieldOps;
3606  for (auto &block : adaptor.getBodyGraph())
3607  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3608  yieldOps.push_back(returnOp);
3609 
3610  // TOSA's while must have a tosa.yield as its terminator. If none is found,
3611  // tosa.while is invalid.
3612  if (yieldOps.empty())
3613  return failure();
3614 
3615  // Get the initial type information from the operand types.
3616  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3617  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3618  for (auto operand : yieldOps.front().getOperands()) {
3619  resultKnowledge.push_back(
3620  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3621  }
3622 
3623  for (auto yieldOp : yieldOps) {
3624  if (resultKnowledge.size() != yieldOp.getNumOperands())
3625  return failure();
3626 
3627  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3628  int32_t index = it.index();
3629  if (auto meet = ValueKnowledge::meet(
3630  resultKnowledge[index],
3631  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3632  resultKnowledge[index] = meet;
3633  }
3634  }
3635  }
3636 
3637  for (const ValueKnowledge &result : resultKnowledge) {
3638  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3639  }
3640 
3641  return success();
3642 }
3643 
3644 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3645  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3646  return llvm::to_vector<4>(vt.getShape());
3647  return std::nullopt;
3648 }
3649 
3650 // The parse and print of IfOp follow the implementation of the SCF dialect.
3651 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3652  // Create the regions for 'then'.
3653  result.regions.reserve(2);
3654  Region *thenRegion = result.addRegion();
3655  Region *elseRegion = result.addRegion();
3656 
3657  auto &builder = parser.getBuilder();
3658  OpAsmParser::UnresolvedOperand cond;
3659  // Create an i1 tensor type for the boolean condition.
3660  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3661  if (parser.parseOperand(cond) ||
3662  parser.resolveOperand(cond, i1Type, result.operands))
3663  return failure();
3664  // Parse optional results type list.
3665  if (parser.parseOptionalArrowTypeList(result.types))
3666  return failure();
3667  // Parse the 'then' region.
3668  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3669  return failure();
3670 
3671  // If we find an 'else' keyword then parse the 'else' region.
3672  if (!parser.parseOptionalKeyword("else")) {
3673  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3674  return failure();
3675  }
3676 
3677  // Parse the optional attribute list.
3678  if (parser.parseOptionalAttrDict(result.attributes))
3679  return failure();
3680  return success();
3681 }
3682 
3683 void IfOp::print(OpAsmPrinter &p) {
3684  bool printBlockTerminators = false;
3685 
3686  p << " " << getCondition();
3687  if (!getResults().empty()) {
3688  p << " -> (" << getResultTypes() << ")";
3689  // Print yield explicitly if the op defines values.
3690  printBlockTerminators = true;
3691  }
3692  p << ' ';
3693  p.printRegion(getThenGraph(),
3694  /*printEntryBlockArgs=*/false,
3695  /*printBlockTerminators=*/printBlockTerminators);
3696 
3697  // Print the 'else' region if it exists and has a block.
3698  auto &elseRegion = getElseGraph();
3699  if (!elseRegion.empty()) {
3700  p << " else ";
3701  p.printRegion(elseRegion,
3702  /*printEntryBlockArgs=*/false,
3703  /*printBlockTerminators=*/printBlockTerminators);
3704  }
3705 
3706  p.printOptionalAttrDict((*this)->getAttrs());
3707 }
3708 
3709 LogicalResult IfOp::verify() {
3710  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3711  "'then_graph' arguments", getInputList(),
3712  "'input_list'")
3713  .failed())
3714  return failure();
3715 
3716  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3717  "'else_graph' arguments", getInputList(),
3718  "'input_list'")
3719  .failed())
3720  return failure();
3721 
3722  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3723  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3724  "'then_graph' results", getOutputList(),
3725  "'output_list'")
3726  .failed())
3727  return failure();
3728 
3729  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3730  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3731  "'else_graph' results", getOutputList(),
3732  "'output_list'")
3733  .failed())
3734  return failure();
3735 
3736  auto condType = getCondition().getType();
3737  if (errorIfShapeNotSizeOne(*this, condType).failed())
3738  return emitOpError() << "'condition' must be a size 1 tensor, got "
3739  << condType;
3740 
3741  return success();
3742 }
3743 
3744 LogicalResult WhileOp::verify() {
3745  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3746  getOutputList(), "'output_list'")
3747  .failed())
3748  return failure();
3749 
3750  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3751  "'cond_graph' arguments", getInputList(),
3752  "'input_list'")
3753  .failed())
3754  return failure();
3755 
3756  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3757  "'body_graph' arguments", getInputList(),
3758  "'input_list'")
3759  .failed())
3760  return failure();
3761 
3762  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3763  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3764  "'body_graph' results", getInputList(),
3765  "'input_list'")
3766  .failed())
3767  return failure();
3768 
3769  // Condition block output must be a single element tensor with a single bool
3770  // value.
3771  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3772  if (condYield.getInputs().size() != 1)
3773  return emitOpError() << "require 'cond_graph' only have one result";
3774 
3775  auto condOutType = condYield.getInputs()[0].getType();
3776  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3777  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3778  << condOutType;
3779 
3780  if (!getElementTypeOrSelf(condOutType).isInteger(1))
3781  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3782  << condOutType;
3783 
3784  return success();
3785 }
3786 
3787 LogicalResult ReverseOp::verify() {
3788  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3789  /* outType = */ getOutput().getType())
3790  .failed())
3791  return failure();
3792  TensorType inputType = getInput1().getType();
3793  TensorType outputType = getOutput().getType();
3794  int32_t reverseAxis = getAxis();
3795 
3796  if (reverseAxis < 0)
3797  return emitOpError("expected non-negative reverse axis");
3798  if (inputType.hasRank()) {
3799  int64_t inputRank = inputType.getRank();
3800  // We allow for a special case where the input/output shape has rank 0 and
3801  // axis is also 0.
3802  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3803  return emitOpError("expect input tensor rank (")
3804  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3805  << ")";
3806  }
3807  if (outputType.hasRank()) {
3808  int64_t outputRank = outputType.getRank();
3809  if (inputType.hasRank() && outputRank != inputType.getRank())
3810  return emitOpError(
3811  "expect output tensor rank to be equal to input tensor rank");
3812  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3813  return emitOpError("expect output tensor rank (")
3814  << outputRank << ") to be larger than reverse axis ("
3815  << reverseAxis << ")";
3816  }
3817  return success();
3818 }
3819 
3820 LogicalResult tosa::SelectOp::verify() {
3821  // verify input2 and input3 have same element type as output
3822  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
3823  /* outType = */ getOutput().getType())
3824  .failed() ||
3825  verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
3826  /* outType = */ getOutput().getType())
3827  .failed()) {
3828  return failure();
3829  }
3830  // verify input1 has element type of bool
3831  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
3832  if (!predicateType) {
3833  return emitOpError("expect shaped tensor for input1, got ")
3834  << getInput1().getType();
3835  }
3836  auto predicateElementType = predicateType.getElementType();
3837  if (!predicateElementType.isInteger(1)) {
3838  return emitOpError("expect element type of bool for input1, got ")
3839  << predicateElementType;
3840  }
3841 
3842  return success();
3843 }
3844 
3845 LogicalResult tosa::VariableOp::verify() {
3846  StringRef symName = getName();
3847  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3848  if (succeeded(varOp))
3849  return emitOpError("illegal to have multiple declaration of '")
3850  << symName << "'";
3851 
3852  return success();
3853 }
3854 
3855 LogicalResult tosa::VariableReadOp::verify() {
3856  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3857  .failed())
3858  return failure();
3859 
3860  return success();
3861 }
3862 
3863 LogicalResult tosa::VariableWriteOp::verify() {
3864  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3865  .failed())
3866  return failure();
3867 
3868  return success();
3869 }
3870 
3871 // The parse and print of WhileOp follow the implementation of the SCF dialect.
3872 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3873  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3874  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3875  Region *cond = result.addRegion();
3876  Region *body = result.addRegion();
3877 
3878  OptionalParseResult listResult =
3879  parser.parseOptionalAssignmentList(regionArgs, operands);
3880  if (listResult.has_value() && failed(listResult.value()))
3881  return failure();
3882 
3883  FunctionType functionType;
3884  SMLoc typeLoc = parser.getCurrentLocation();
3885  if (failed(parser.parseColonType(functionType)))
3886  return failure();
3887 
3888  result.addTypes(functionType.getResults());
3889 
3890  if (functionType.getNumInputs() != operands.size()) {
3891  return parser.emitError(typeLoc)
3892  << "expected as many input types as operands "
3893  << "(expected " << operands.size() << " got "
3894  << functionType.getNumInputs() << ")";
3895  }
3896 
3897  // Resolve input operands.
3898  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3899  parser.getCurrentLocation(),
3900  result.operands)))
3901  return failure();
3902 
3903  // Propagate the types into the region arguments.
3904  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3905  regionArgs[i].type = functionType.getInput(i);
3906 
3907  return failure(parser.parseRegion(*cond, regionArgs) ||
3908  parser.parseKeyword("do") || parser.parseRegion(*body) ||
3909  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3910 }
3911 
3912 static void printInitializationList(OpAsmPrinter &parser,
3913  Block::BlockArgListType blocksArgs,
3914  ValueRange initializers,
3915  StringRef prefix = "") {
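// Prints the region arguments next to their initializers, for example
// (hypothetical SSA names): " (%arg0 = %init0, %arg1 = %init1)".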
3916  assert(blocksArgs.size() == initializers.size() &&
3917  "expected same length of arguments and initializers");
3918  if (initializers.empty())
3919  return;
3920 
3921  parser << prefix << '(';
3922  llvm::interleaveComma(
3923  llvm::zip(blocksArgs, initializers), parser,
3924  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3925  parser << ")";
3926 }
3927 
3928 void WhileOp::print(OpAsmPrinter &parser) {
3929  printInitializationList(parser, getCondGraph().front().getArguments(),
3930  getInputList(), " ");
3931  parser << " : ";
3932  parser.printFunctionalType(getInputList().getTypes(),
3933  getResults().getTypes());
3934  parser << ' ';
3935  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3936  parser << " do ";
3937  parser.printRegion(getBodyGraph());
3938  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3939 }
3940 
3941 // Create a rank-1 const tensor for zero point of the source tensor.
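// For example, with hypothetical values srcElemType = i8 and zp = -128, this
// materializes a tensor<1xi8> tosa.const holding -128; for float element
// types the zero point is converted to the matching float constant.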
3942 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3943  Location loc,
3944  Type srcElemType,
3945  int64_t zp) {
3946  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3947  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3948  if (llvm::isa<FloatType>(srcElemType)) {
3949  auto zpAttr = DenseElementsAttr::get(
3950  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3951  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3952  }
3953  if (llvm::isa<IntegerType>(srcElemType)) {
3954  auto zpAttr =
3955  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3956  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3957  }
3958  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3959  return std::nullopt;
3960 }
3961 
3962 //===----------------------------------------------------------------------===//
3963 // TOSA Shape and Shape Operators Helper functions.
3964 //===----------------------------------------------------------------------===//
3965 
3966 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3967  return mlir::isa<tosa::shapeType>(t);
3968 }
3969 
3970 LogicalResult
3971 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3972  int rank) {
3973  if (rank < 0)
3974  return emitError() << "invalid rank (must be >= 0): " << rank;
3975  return success();
3976 }
3977 
3979  for (auto v : op->getOperands()) {
3980  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3981  Operation *definingOp = v.getDefiningOp();
3982  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3983  return op->emitOpError("shape operand is not compile time resolvable");
3984  }
3985  }
3986  }
3987  return success();
3988 }
3989 
3991  for (auto type : op->getOperandTypes()) {
3992  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3993  return op->emitOpError("must have operands with tosa shape type");
3994  }
3995  }
3996  for (auto type : op->getResultTypes()) {
3997  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3998  return op->emitOpError("must have result with tosa shape type");
3999  }
4000  }
4001  return success();
4002 }
4003 
4004 LogicalResult
4006  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4007  failed(verifyTosaShapeOperator(op)))
4008  return failure();
4009 
4010  // Helper that returns the rank of a tosa shape type.
4011  auto getRank = [](const Type type) {
4012  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4013  };
4014  auto operandTypes = op->getOperandTypes();
4015  auto resultTypes = op->getResultTypes();
4016 
4017  auto rank = getRank(*op->getOperandTypes().begin());
4018  for (auto type : operandTypes) {
4019  if (getRank(type) != rank) {
4020  return op->emitOpError("operands don't have matching ranks");
4021  }
4022  }
4023  for (auto type : resultTypes) {
4024  if (getRank(type) != rank) {
4025  return op->emitOpError("result shape has different rank than operands");
4026  }
4027  }
4028  return success();
4029 }
4030 
4031 //===----------------------------------------------------------------------===//
4032 // TOSA Shape Operators verify functions.
4033 //===----------------------------------------------------------------------===//
4034 
4035 LogicalResult tosa::ConstShapeOp::verify() {
4036  // Check that the values attribute is one-dimensional (rank 1).
4037  auto valuesRank = getValues().getType().getRank();
4038  if (valuesRank != 1)
4039  return emitOpError("expect elements in attribute values with rank 1");
4040  // Check that the number of elements in the values attribute equals the rank
4041  // of the result shape.
4041  auto count = getValues().getNumElements();
4042  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4043  if (!(count == rank || (count == 1 && rank == 0))) {
4044  return emitOpError("expect number of elements in attribute values (")
4045  << count << ") to be equal to the rank (" << rank
4046  << ") for the result shape type";
4047  }
4048  return success();
4049 }
4050 
4051 //===----------------------------------------------------------------------===//
4052 // TOSA Attribute Definitions.
4053 //===----------------------------------------------------------------------===//
4054 
4055 #define GET_ATTRDEF_CLASSES
4056 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4057 
4058 //===----------------------------------------------------------------------===//
4059 // TOSA Type Definitions.
4060 //===----------------------------------------------------------------------===//
4061 #define GET_TYPEDEF_CLASSES
4062 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4063 
4064 //===----------------------------------------------------------------------===//
4065 // TOSA Operator Definitions.
4066 //===----------------------------------------------------------------------===//
4067 
4068 #define GET_OP_CLASSES
4069 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:1035
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition: TosaOps.cpp:2183
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:727
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2746
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition: TosaOps.cpp:296
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
Definition: TosaOps.cpp:674
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition: TosaOps.cpp:664
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2771
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:336
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition: TosaOps.cpp:708
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2961
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:1149
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition: TosaOps.cpp:1163
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:2796
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2864
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition: TosaOps.cpp:2253
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:1013
static LogicalResult verifyConvOpErrorIf(T op)
Definition: TosaOps.cpp:482
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:441
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:279
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2852
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:2762
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:1194
Type getStorageElementTypeOrSelf(Type type)
Definition: TosaOps.cpp:285
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition: TosaOps.cpp:1109
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:989
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:1064
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition: TosaOps.cpp:622
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:138
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition: TosaOps.cpp:2211
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:3912
static LogicalResult verifyPoolingOp(T op)
Definition: TosaOps.cpp:790
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition: TosaOps.cpp:1283
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:252
IntegerType getI32Type()
Definition: Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:191
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
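A hedged sketch (hypothetical helper) combining the ShapeAdaptor queries above: rank is checked first, dimensions are copied out, and element counts are only queried for fully static shapes.
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

using namespace mlir;

// Illustrative only: leave `dims` empty when the rank is unknown.
static void summarizeShape(ShapeAdaptor shape, SmallVectorImpl<int64_t> &dims) {
  if (!shape.hasRank())
    return;              // Unranked: nothing further can be queried.
  shape.getDims(dims);   // Dynamic dims come back as ShapedType::kDynamic.
  int64_t numDynamic = 0;
  for (int i = 0, e = shape.getRank(); i < e; ++i)
    if (shape.isDynamicDim(i))
      ++numDynamic;
  if (shape.hasStaticShape())
    assert(numDynamic == 0 && shape.getNumElements() >= 0);
}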
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
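A small sketch (hypothetical acceptance policy) using the Type predicates listed above together with getElementTypeOrSelf.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

// Illustrative only: accept f16/bf16/f32 elements, or integers of >= 8 bits.
static bool isSupportedElementType(Type type) {
  Type elemTy = getElementTypeOrSelf(type);
  if (elemTy.isF16() || elemTy.isBF16() || elemTy.isF32())
    return true;
  return elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() >= 8;
}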
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of the index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
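A minimal sketch (hypothetical attribute name) of the advance/interrupt protocol in Operation::walk.
#include "mlir/IR/Operation.h"
#include "mlir/IR/Visitors.h"

using namespace mlir;

// Illustrative only: interrupt() aborts the walk at the first match,
// advance() keeps traversing nested operations.
static bool containsQuantizedOp(Operation *root) {
  WalkResult result = root->walk([](Operation *op) {
    if (op->hasAttr("quantized"))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return result.wasInterrupted();
}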
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:917
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:3990
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4005
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:3978
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
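A sketch with made-up shapes showing the getBroadcastedShape contract: it returns true and fills the result only when the shapes are broadcast-compatible.
#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative only: {1, 3, 4} and {5, 3, 1} broadcast to {5, 3, 4};
// {2, 3} against {3, 2} would return false instead.
static bool exampleBroadcast(SmallVectorImpl<int64_t> &result) {
  SmallVector<int64_t> lhs = {1, 3, 4};
  SmallVector<int64_t> rhs = {5, 3, 1};
  return OpTrait::util::getBroadcastedShape(lhs, rhs, result);
}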
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:227
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:252
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:3942
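A hedged wrapper sketch (hypothetical helper name) around tosa::createZeroPointTensor; the std::optional result is empty for element types the helper cannot handle.
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include <cassert>
#include <optional>

using namespace mlir;

// Illustrative only: materialize a zero point of 0 for the given element type.
static Value requireZeroPoint(OpBuilder &builder, Location loc, Type elemTy) {
  std::optional<Value> zp =
      tosa::createZeroPointTensor(builder, loc, elemTy, /*zp=*/0);
  assert(zp.has_value() && "element type has no zero-point representation");
  return *zp;
}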
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:3966
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:316
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
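A short sketch (hypothetical helper) of the common matchPattern + m_Constant idiom for extracting the attribute behind a constant-producing value.
#include "mlir/IR/Matchers.h"

using namespace mlir;

// Illustrative only: returns the folded constant attribute, or a null
// Attribute when `v` is not produced by a constant-foldable op.
static Attribute getConstantValue(Value v) {
  Attribute attr;
  if (matchPattern(v, m_Constant(&attr)))
    return attr;
  return {};
}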
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
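A sketch (tosa.identity picked as a simple one-operand, one-result op) of filling an OperationState by hand and creating the operation through OpBuilder; addAttribute and addRegion are used the same way.
#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

using namespace mlir;

// Illustrative only: build an op generically from its textual name.
static Operation *buildGenericIdentity(OpBuilder &builder, Location loc,
                                       Value input) {
  OperationState state(loc, "tosa.identity");
  state.addOperands(input);        // operands of the new op
  state.addTypes(input.getType()); // result types
  return builder.create(state);    // creates and inserts the operation
}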
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45
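A hedged sketch (made-up input types) of refining shape facts with ValueKnowledge: meet keeps the information required to hold for both inputs, so knowledge from tensor<4x?xf32> met with knowledge from tensor<?x8xf32> refines toward a 4x8xf32 fact.
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;
using namespace mlir::tosa;

// Illustrative only: derive knowledge from two types and merge it.
static ValueKnowledge refine(Type lhsTy, Type rhsTy) {
  ValueKnowledge lhs = ValueKnowledge::getKnowledgeFromType(lhsTy);
  ValueKnowledge rhs = ValueKnowledge::getKnowledgeFromType(rhsTy);
  return ValueKnowledge::meet(lhs, rhs);
}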