1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
47 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
48 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
49 
50 namespace {
51 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
52 
53 //===----------------------------------------------------------------------===//
54 // Dialect Function Inliner Interface.
55 //===----------------------------------------------------------------------===//
56 struct TosaInlinerInterface : public DialectInlinerInterface {
57  using DialectInlinerInterface::DialectInlinerInterface;
58 
59  //===--------------------------------------------------------------------===//
60  // Analysis Hooks.
61  //===--------------------------------------------------------------------===//
62 
63  /// All operations can be inlined by default.
64  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return true;
67  }
68 
69  /// All regions with If and While parent operators can be inlined.
70  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
71  IRMapping &map) const final {
72  return (isa<tosa::IfOp>(dest->getParentOp()) ||
73  isa<tosa::WhileOp>(dest->getParentOp()));
74  }
75 };
76 
77 /// This class implements the bytecode interface for the Tosa dialect.
78 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
79  TosaDialectBytecodeInterface(Dialect *dialect)
80  : BytecodeDialectInterface(dialect) {}
81 
82  //===--------------------------------------------------------------------===//
83  // Attributes
84 
85  Attribute readAttribute(DialectBytecodeReader &reader) const override {
86  return ::readAttribute(getContext(), reader);
87  }
88 
89  LogicalResult writeAttribute(Attribute attr,
90  DialectBytecodeWriter &writer) const override {
91  return ::writeAttribute(attr, writer);
92  }
93 
94  //===--------------------------------------------------------------------===//
95  // Types
96 
97  Type readType(DialectBytecodeReader &reader) const override {
98  return ::readType(getContext(), reader);
99  }
100 
101  LogicalResult writeType(Type type,
102  DialectBytecodeWriter &writer) const override {
103  return ::writeType(type, writer);
104  }
105 
106  void writeVersion(DialectBytecodeWriter &writer) const final {
107  // TODO: Populate.
108  }
109 
110  std::unique_ptr<DialectVersion>
111  readVersion(DialectBytecodeReader &reader) const final {
112  // TODO: Populate
113  reader.emitError("Dialect does not support versioning");
114  return nullptr;
115  }
116 
117  LogicalResult upgradeFromVersion(Operation *topLevelOp,
118  const DialectVersion &version) const final {
119  return success();
120  }
121 };
122 
123 } // namespace
124 
125 //===----------------------------------------------------------------------===//
126 // TOSA control flow support.
127 //===----------------------------------------------------------------------===//
128 
129 /// Returns the while loop body.
130 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
131  return {&getBodyGraph()};
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // TOSA variable operator support.
136 //===----------------------------------------------------------------------===//
137 
138 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
139  return to_vector(llvm::map_range(shape, [](int64_t dim) {
140  return dim == -1 ? ShapedType::kDynamic : dim;
141  }));
142 }
143 
144 // Returns the tensor type of a tosa.variable op.
145 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
146  Type elementType = variableOp.getType();
147  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
148  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
149  return RankedTensorType::get(shape, elementType);
150 }
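// For example, a variable declared as
//   tosa.variable @v = dense<0.0> : tensor<2x4xf32>
// carries var_shape = [2, 4] and type = f32, so getVariableType reconstructs
// tensor<2x4xf32> for it.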
151 
152 //===----------------------------------------------------------------------===//
153 // Tosa dialect initialization.
154 //===----------------------------------------------------------------------===//
155 
156 void TosaDialect::initialize() {
157  addTypes<
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
160  >();
161  addOperations<
162 #define GET_OP_LIST
163 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
164  >();
165  addAttributes<
166 #define GET_ATTRDEF_LIST
167 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
168  >();
169  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
170  declarePromisedInterfaces<
171  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
172  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
173  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
174  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
175  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
176  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
177  GreaterEqualOp, MatMulOp>();
178 }
179 
180 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
181  Type type, Location loc) {
182  // Tosa dialect constants only support ElementsAttr, unlike the standard
183  // dialect constant, which supports all attributes.
184  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
185  return builder.create<tosa::ConstShapeOp>(
186  loc, type, llvm::cast<DenseIntElementsAttr>(value));
187  }
188  if (llvm::isa<ElementsAttr>(value))
189  return builder.create<tosa::ConstOp>(loc, type,
190  llvm::cast<ElementsAttr>(value));
191  return nullptr;
192 }
193 
194 //===----------------------------------------------------------------------===//
195 // Parsers and printers
196 //===----------------------------------------------------------------------===//
197 
198 namespace {
199 
200 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
201  DenseElementsAttr &varShapeAttr,
202  TypeAttr &typeAttr) {
203  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
204  if (!shapedType.hasRank())
205  return parser.emitError(parser.getCurrentLocation())
206  << "expected ranked type";
207 
208  auto elementType = shapedType.getElementType();
209  typeAttr = TypeAttr::get(elementType);
210  ArrayRef<int64_t> shape = shapedType.getShape();
211  Builder builder(parser.getContext());
212  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
213  return success();
214  }
215  return parser.emitError(parser.getCurrentLocation())
216  << "expected shaped type";
217 }
218 
219 } // namespace
220 
221 // parses the optional initial value or type for a tosa variable
222 // with initial value:
223 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
224 //
225 // without initial value:
226 // tosa.variable @name : tensor<1x8xf32>
227 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
228  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
229  Attribute &initialValueAttr) {
230  if (succeeded(parser.parseOptionalEqual())) {
231  if (failed(parser.parseAttribute(initialValueAttr))) {
232  return parser.emitError(parser.getCurrentLocation())
233  << "expected attribute";
234  }
235  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
236  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
237  typeAttr);
238  }
239  return parser.emitError(parser.getCurrentLocation())
240  << "expected Typed attr";
241  }
242 
243  initialValueAttr = nullptr;
244  Type parsedType;
245  if (failed(parser.parseColonType(parsedType))) {
246  return parser.emitError(parser.getCurrentLocation())
247  << "expected type after colon";
248  }
249  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
250 }
251 
252 void mlir::tosa::printVariableOpTypeOrInitialValue(
253  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
254  TypeAttr typeAttr, Attribute initialValueAttr) {
255  bool needsSpace = false;
256  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
257  auto shape =
258  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
259  Type elementType = typeAttr.getValue();
260  RankedTensorType tensorType =
261  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
262  auto tensorTypeAttr = TypeAttr::get(tensorType);
263  p << ": ";
264  p.printAttribute(tensorTypeAttr);
265  needsSpace = true; // subsequent attr value needs a space separator
266  }
267  if (initialValueAttr) {
268  if (needsSpace)
269  p << ' ';
270  p << "= ";
271  p.printAttribute(initialValueAttr);
272  }
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Tosa utilities.
277 //===----------------------------------------------------------------------===//
278 
279 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
280  if (lhs % rhs != 0)
281  return std::nullopt;
282  return lhs / rhs;
283 }
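// For example, idivCheck(9, 3) yields 3, while idivCheck(10, 3) yields
// std::nullopt because 10 is not evenly divisible by 3.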
284 
285 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
286  auto srcType = getElementTypeOrSelf(type);
287  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
288  srcType = quantType.getStorageType();
289  return srcType;
290 }
291 
292 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
293  return getStorageElementTypeOrSelf(value.getType());
294 }
295 
296 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
297  Value valZp, StringRef name) {
298  Type eType = getStorageElementTypeOrSelf(val.getType());
299  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
300 
301  bool bothInts =
302  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
303  bool sameBitWidth =
304  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
305 
306  if (!bothInts || !sameBitWidth) {
307  return op->emitOpError()
308  << "expected " << name << " and " << name
309  << "_zp to both be integer of the same bitwidth, but got " << eType
310  << " vs. " << eZpType;
311  }
312  return success();
313 }
314 
315 // Create a pad-const tensor with the value `val` of the required data type.
316 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
317  Value src, int32_t val) {
318  const auto srcType = getElementTypeOrSelf(src);
319  const auto srcElemType = getStorageElementTypeOrSelf(src);
320  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
321  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
322  const auto padConstAttr{
323  llvm::isa<FloatType>(srcElemType)
324  ? DenseElementsAttr::get(padConstEType,
325  builder.getFloatAttr(srcElemType, val))
326  : DenseElementsAttr::get(padConstEType,
327  builder.getIntegerAttr(srcElemType, val))};
328  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
329 }
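// As a rough illustration, for an i8 input and val = 0 the constant built
// above is equivalent to something like:
//   %pad_const = "tosa.const"() <{values = dense<0> : tensor<1xi8>}>
//                : () -> tensor<1xi8>
// For quantized inputs the result keeps the quantized element type while the
// attribute uses the storage type (padConstEType above).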
330 
331 //===----------------------------------------------------------------------===//
332 // TOSA Operator Verifiers.
333 //===----------------------------------------------------------------------===//
334 
335 template <typename T>
336 static LogicalResult verifyConvOp(T op) {
337  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
338  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
339 
340  auto inputEType = inputType.getElementType();
341  auto weightEType = weightType.getElementType();
342  auto biasEType =
343  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
344  auto resultEType =
345  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
346  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
347  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
348 
349  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
350  inputEType = quantType.getStorageType();
351 
352  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
353  weightEType = quantType.getStorageType();
354 
355  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
356  biasEType = quantType.getStorageType();
357 
358  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
359  resultEType = quantType.getStorageType();
360 
361  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
362  // for now, only enforce bias element type == result element type for
363  // float types.
364  op.emitOpError(
365  "expect both bias and result to have same element type, got ")
366  << biasEType << " and " << resultEType;
367  return failure();
368  }
369 
370  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
371  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
372  if (inputEType != weightEType) {
373  op.emitOpError(
374  "expect both input and weight to have same element type, got ")
375  << inputEType << " and " << weightEType;
376  return failure();
377  }
378  }
379 
380  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
381  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
382 
383  // Either both must be float or both non-float.
384  if (inputIsFloat != weightIsFloat) {
385  op.emitOpError(
386  "expect both input and weight to be float or not together, got ")
387  << inputEType << " and " << weightEType;
388  return failure();
389  }
390 
391  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
392  if (inputEType != inputZpEType) {
393  return op.emitOpError("expect both input and its zero point are the same "
394  "element type, got ")
395  << inputEType << " and " << inputZpEType;
396  }
397 
398  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
399  if (weightEType != weightZpEType) {
400  return op.emitOpError("expect both weight and its zero point are the same "
401  "element type, got ")
402  << weightEType << " and " << weightZpEType;
403  }
404 
405  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
406  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
407  return failure();
408 
409  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
410  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
411  return failure();
412 
413  return success();
414 }
415 
416 LogicalResult tosa::ConstOp::verify() {
417 
418  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
419  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
420 
421  if (!attrType || !outputType) {
422  emitOpError("expected tensors for attr/result type");
423  return failure();
424  }
425 
426  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
427  outputType.getElementType())) {
428  if (result.getStorageType() == attrType.getElementType())
429  return success();
430  }
431 
432  if (attrType.getElementType() != outputType.getElementType()) {
433  emitOpError("expected same attr/result element types");
434  return failure();
435  }
436 
437  return success();
438 }
439 
440 template <typename T>
441 static LogicalResult verifyConvOpModes(T op) {
442  auto inputEType =
443  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
444 
445  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
446  inputEType = quantType.getStorageType();
447 
448  auto accType = op.getAccType();
449  if (inputEType.isInteger(8) && !accType.isInteger(32))
450  return op.emitOpError("accumulator type for i8 tensor is not i32");
451 
452  if (inputEType.isInteger(16) && !accType.isInteger(48))
453  return op.emitOpError("accumulator type for i16 tensor is not i48");
454 
455  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
456  return op.emitOpError("accumulator type for f8 tensor is not f16");
457 
458  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
459  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
460 
461  if (inputEType.isBF16() && !accType.isF32())
462  return op.emitOpError("accumulator type for bf16 tensor is not f32");
463 
464  if (inputEType.isF32() && !accType.isF32())
465  return op.emitOpError("accumulator type for f32 tensor is not f32");
466 
467  auto resultEType =
468  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
469 
470  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
471  resultEType = quantType.getStorageType();
472 
473  return success();
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // ERROR_IF functions.
478 // ERROR_IF is a predicate that must set an error if the condition holds.
479 //===----------------------------------------------------------------------===//
480 
481 template <typename T>
482 static LogicalResult verifyConvOpErrorIf(T op) {
483  llvm::ArrayRef<int64_t> padding = op.getPad();
484  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
485  return op.emitOpError("expect all padding values to be >= 0, got ")
486  << padding;
487 
488  llvm::ArrayRef<int64_t> strides = op.getStride();
489  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
490  return op.emitOpError("expect all stride values to be >= 1, got ")
491  << strides;
492 
493  llvm::ArrayRef<int64_t> dilations = op.getDilation();
494  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
495  return op.emitOpError("expect all dilation values to be >= 1, got ")
496  << dilations;
497 
498  const RankedTensorType outputType =
499  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
500  if (!outputType)
501  // Skip following checks if output is not ranked
502  return success();
503 
504  const RankedTensorType inputType =
505  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
506  const RankedTensorType weightType =
507  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
508 
509  if (inputType && weightType) {
510  const auto verifyOutputSize =
511  [&op](const int64_t inputSize, const int64_t kernelSize,
512  const int64_t outputSize, const int64_t padBefore,
513  const int64_t padAfter, const int64_t stride,
514  const int64_t dilation, const llvm::StringRef dimName,
515  const llvm::StringRef dimAxis,
516  const llvm::StringRef padBeforeName,
517  const llvm::StringRef padAfterName) -> LogicalResult {
518  if (inputSize == ShapedType::kDynamic ||
519  kernelSize == ShapedType::kDynamic)
520  return success();
521 
522  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
523 
524  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
525  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
526  stride);
527  if (!calculatedOutSizeMinusOne.has_value())
528  return op.emitOpError("expected input_")
529  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
530  << padAfterName << " - (kernel_" << dimName
531  << " - 1) * dilation_" << dimAxis
532  << " to be wholly divisible by stride_" << dimAxis << ", got ("
533  << inputSize << " - 1 + " << padBefore << " + " << padAfter
534  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
535  << stride;
536 
537  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
538  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
539  return op.emitOpError("calculated output ")
540  << dimName << " did not match expected: "
541  << "calculated=" << calculatedOutSize
542  << ", expected=" << outputSize;
543 
544  return success();
545  };
546 
547  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
548  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
549  if (failed(verifyOutputSize(
550  inputType.getDimSize(1), weightType.getDimSize(1),
551  outputType.getDimSize(1), padding[0], padding[1], strides[0],
552  dilations[0], "height", "y", "top", "bottom")))
553  return failure();
554 
555  if (failed(verifyOutputSize(
556  inputType.getDimSize(2), weightType.getDimSize(2),
557  outputType.getDimSize(2), padding[2], padding[3], strides[1],
558  dilations[1], "width", "x", "left", "right")))
559  return failure();
560  }
561 
562  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
563  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
564  if (failed(verifyOutputSize(
565  inputType.getDimSize(1), weightType.getDimSize(0),
566  outputType.getDimSize(1), padding[0], padding[1], strides[0],
567  dilations[0], "height", "y", "top", "bottom")))
568  return failure();
569 
570  if (failed(verifyOutputSize(
571  inputType.getDimSize(2), weightType.getDimSize(1),
572  outputType.getDimSize(2), padding[2], padding[3], strides[1],
573  dilations[1], "width", "x", "left", "right")))
574  return failure();
575  }
576 
577  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
578  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
579  if (failed(verifyOutputSize(
580  inputType.getDimSize(1), weightType.getDimSize(1),
581  outputType.getDimSize(1), padding[0], padding[1], strides[0],
582  dilations[0], "depth", "d", "front", "back")))
583  return failure();
584 
585  if (failed(verifyOutputSize(
586  inputType.getDimSize(2), weightType.getDimSize(2),
587  outputType.getDimSize(2), padding[2], padding[3], strides[1],
588  dilations[1], "height", "y", "top", "bottom")))
589  return failure();
590 
591  if (failed(verifyOutputSize(
592  inputType.getDimSize(3), weightType.getDimSize(3),
593  outputType.getDimSize(3), padding[4], padding[5], strides[2],
594  dilations[2], "width", "x", "left", "right")))
595  return failure();
596  }
597  }
598 
599  const RankedTensorType biasType =
600  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
601  if (!biasType)
602  // Skip following checks if bias is not ranked
603  return success();
604 
605  const int64_t biasChannels = biasType.getDimSize(0);
606  const int64_t outputChannels =
607  outputType.getDimSize(outputType.getRank() - 1);
608  if (biasChannels == ShapedType::kDynamic ||
609  outputChannels == ShapedType::kDynamic)
610  // Skip following checks if biasChannels or outputChannels is dynamic dim
611  return success();
612 
613  if (biasChannels != outputChannels && biasChannels != 1)
614  return op.emitOpError(
615  "bias channels expected to be equal to output channels (")
616  << outputChannels << ") or 1, got " << biasChannels;
617 
618  return success();
619 }
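// Worked example of the ERROR_IF size check above: with input height I = 16,
// kernel K = 3, pad_top = pad_bottom = 1, stride = 1 and dilation = 1, the
// expected output height is (16 - 1 + 1 + 1 - (3 - 1) * 1) / 1 + 1 = 16; a
// declared output height other than 16, or a numerator not divisible by the
// stride, is rejected.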
620 
621 // Verify that the two given types have the same element type and shape.
622 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
623  StringRef name1, Type type2,
624  StringRef name2) {
625  auto shapeType1 = dyn_cast<ShapedType>(type1);
626  auto shapeType2 = dyn_cast<ShapedType>(type2);
627  if (!shapeType1 || !shapeType2)
628  return failure();
629 
630  auto elemType1 = shapeType1.getElementType();
631  auto elemType2 = shapeType2.getElementType();
632  if (elemType1 != elemType2)
633  return op->emitOpError()
634  << "require same element type for " << name1 << " (" << elemType1
635  << ") and " << name2 << " (" << elemType2 << ")";
636 
637  if (failed(verifyCompatibleShape(type1, type2)))
638  return op->emitOpError()
639  << "require same shapes for " << name1 << " (" << type1 << ") and "
640  << name2 << " (" << type2 << ")";
641 
642  return success();
643 }
644 
645 // Verify that the two given tensor lists have the same length and that their
646 // corresponding types have the same element type and shape.
646 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
647  StringRef name1,
648  ValueRange list2,
649  StringRef name2) {
650  if (list1.size() != list2.size())
651  return op->emitOpError()
652  << "require same number of values in " << name1 << " ("
653  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
654 
655  for (auto [type1, type2] :
656  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
657  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
658  return failure();
659  }
660 
661  return success();
662 }
663 
664 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
665  ShapeAdaptor shapeAdaptor(type);
666  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
667  return success();
668 
669  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
670 }
671 
672 // Returns the first declaration point prior to this operation or failure if
673 // not found.
674 static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
675  StringRef symName) {
676  ModuleOp module = op->getParentOfType<ModuleOp>();
677  tosa::VariableOp varOp = nullptr;
678 
679  // TODO: Adopt the SymbolTable trait for Variable ops.
680  // Currently, the variable's definition point is searched via walk(),
681  // starting from the top-level ModuleOp and stopping at the point of use.
682  // Once the TOSA control-flow and variable extensions are complete, this may
683  // leverage MLIR's SymbolTable functionality to look up the symbol and refine
684  // the search into a TOSA-specific graph traversal over the IR structure.
685  module.walk([&](Operation *tempOp) {
686  // Reach this op itself.
687  if (tempOp == op) {
688  return WalkResult::interrupt();
689  }
690 
691  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
692  if (symName == tosaOp.getName()) {
693  varOp = tosaOp;
694  return WalkResult::interrupt();
695  }
696  }
697 
698  return WalkResult::advance();
699  });
700 
701  if (varOp)
702  return varOp;
703 
704  return failure();
705 }
706 
707 template <typename T>
708 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
709  StringRef symName = op.getName();
710  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
711  if (failed(varOp))
712  return op->emitOpError("'")
713  << symName << "' has not been declared by 'tosa.variable'";
714 
715  // Verify type and shape
716  auto variableType = getVariableType(varOp.value());
717  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
718  "the input tensor")
719  .failed())
720  return failure();
721 
722  return success();
723 }
724 
725 // Verify that inType and outType have the same element type.
726 template <typename T>
727 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
728  auto inputType = llvm::dyn_cast<TensorType>(inType);
729  auto outputType = llvm::dyn_cast<TensorType>(outType);
730  if (!inputType) {
731  op.emitOpError("expect shaped tensor for input, got ") << inType;
732  return failure();
733  }
734  if (!outputType) {
735  op.emitOpError("expect shaped tensor for output, got ") << outType;
736  return failure();
737  }
738  auto inputElementType = inputType.getElementType();
739  auto outputElementType = outputType.getElementType();
740  auto inputQuantType =
741  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
742  auto outputQuantType =
743  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
744  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
745  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
746  inputElementType != outputElementType) {
747  // only check if both element types are int/index/float/UniformQuantized
748  // eg, not sure how to check quant::QuantizedType
749  // this happens in test_conv2d_q_grouped_convolution in
750  // tfl-to-tosa-pipeline.mlir
751  op.emitOpError("expect input and output to have same element type, got ")
752  << inputElementType << " and " << outputElementType;
753  return failure();
754  }
755  return success();
756 }
757 
758 LogicalResult tosa::ArgMaxOp::verify() {
759  const ShapedType resultType = llvm::cast<ShapedType>(getType());
760 
761  // Ensure output is of integer type
762  if (const auto resultETy = resultType.getElementType();
763  !resultETy.isIntOrIndex())
764  return emitOpError("result tensor is not of integer type");
765 
766  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
767  if (!inputType.hasRank())
768  return success();
769 
770  // Ensure axis is within the tensor rank
771  const int64_t axis = getAxisAttr().getInt();
772  if (((axis < 0) || axis >= inputType.getRank()))
773  return emitOpError("specified axis is outside the rank of the tensor");
774 
775  if (!resultType.hasRank())
776  return success();
777 
778  const ArrayRef<int64_t> inputShape = inputType.getShape();
779  const ArrayRef<int64_t> outputShape = resultType.getShape();
780  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
781  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
782  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
783  return emitOpError("expected output shape '")
784  << expectedOutputShape << "', got '" << outputShape << "'";
785 
786  return success();
787 }
788 
789 template <typename T>
790 static LogicalResult verifyPoolingOp(T op) {
791  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
792  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
793  return op.emitOpError("expect all kernel values to be >= 1, got ")
794  << kernel;
795 
796  const llvm::ArrayRef<int64_t> strides = op.getStride();
797  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
798  return op.emitOpError("expect all stride values to be >= 1, got ")
799  << strides;
800 
801  const llvm::ArrayRef<int64_t> padding = op.getPad();
802  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
803  return op.emitOpError("expect all padding values to be >= 0, got ")
804  << padding;
805 
806  // Padding must be less than kernel size to avoid a divide-by-zero
807  const int64_t kernelX = kernel[1];
808  const int64_t padLeft = padding[2];
809  const int64_t padRight = padding[3];
810  if (padRight >= kernelX || padLeft >= kernelX)
811  return op.emitOpError("expected left/right padding to be less than the "
812  "width of the kernel, got pad_left=")
813  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
814 
815  const int64_t kernelY = kernel[0];
816  const int64_t padTop = padding[0];
817  const int64_t padBottom = padding[1];
818  if (padTop >= kernelY || padBottom >= kernelY)
819  return op.emitOpError("expected top/bottom padding to be less than the "
820  "height of the kernel, got pad_top=")
821  << padTop << ", pad_bottom=" << padBottom
822  << ", kernel_y=" << kernelY;
823 
824  const auto inputType =
825  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
826  const auto outputType =
827  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
828  if (!inputType || !outputType)
829  return success();
830 
831  const auto verifyOutputSize =
832  [&op](const int64_t inputSize, const int64_t outputSize,
833  const int64_t kernelSize, const int64_t strideSize,
834  const int64_t padBefore, const int64_t padAfter,
835  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
836  const llvm::StringRef padBeforeName,
837  const llvm::StringRef padAfterName) -> LogicalResult {
838  if (ShapedType::isDynamic(inputSize))
839  return success();
840 
841  const std::optional<int64_t> calculatedOutSizeMinusOne =
842  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
843  if (!calculatedOutSizeMinusOne.has_value())
844  return op.emitOpError("expected input_")
845  << dimName << " + pad_" << padBeforeName << " + pad_"
846  << padAfterName << " - kernel_" << dimAxis
847  << " to be wholly divisible by stride_" << dimAxis << ", got ("
848  << inputSize << " + " << padBefore << " + " << padAfter << " - "
849  << kernelSize << ") / " << strideSize;
850 
851  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
852  if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
853  return op.emitOpError("calculated output ")
854  << dimName << " did not match expected: "
855  << "calculated=" << calculatedOutSize
856  << ", expected=" << outputSize;
857 
858  return success();
859  };
860 
861  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
862  kernel[0], strides[0], padding[0], padding[1],
863  "height", "y", "top", "bottom")))
864  return failure();
865 
866  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
867  kernel[1], strides[1], padding[2], padding[3],
868  "width", "x", "left", "right")))
869  return failure();
870 
871  return success();
872 }
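// Worked example of the pooling size check above: with input height I = 32,
// kernel = 2, stride = 2 and zero padding, the expected output height is
// (32 + 0 + 0 - 2) / 2 + 1 = 16; a remainder in that division or a mismatched
// declared output height produces an error.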
873 
874 LogicalResult tosa::AvgPool2dOp::verify() {
875  if (failed(verifyPoolingOp(*this)))
876  return failure();
877 
878  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
879  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
880  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
881  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
882 
883  auto accType = getAccType();
884  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
885  return emitOpError("accumulator type for integer tensor is not i32");
886 
887  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
888  return emitOpError("accumulator type for f16 tensor is not f16/f32");
889 
890  if (inputETy.isBF16() && !accType.isF32())
891  return emitOpError("accumulator type for bf16 tensor is not f32");
892 
893  if (inputETy.isF32() && !accType.isF32())
894  return emitOpError("accumulator type for f32 tensor is not f32");
895 
896  if (inputETy != inputZpETy)
897  return emitOpError("expect both input and its zero point are the same "
898  "element type, got ")
899  << inputETy << " and " << inputZpETy;
900 
901  if (resultETy != outputZpETy)
902  return emitOpError("expect both output and its zero point are the same "
903  "element type, got ")
904  << resultETy << " and " << outputZpETy;
905 
906  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
907  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
908  return failure();
909 
910  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
911  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
912  return failure();
913 
914  return success();
915 }
916 
917 LogicalResult tosa::ClampOp::verify() {
918  mlir::Type inputETy =
919  llvm::cast<ShapedType>(getInput().getType()).getElementType();
920  if (auto quantType =
921  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
922  inputETy = quantType.getStorageType();
923  }
924  mlir::Type outputETy =
925  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
926  if (auto quantType =
927  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
928  outputETy = quantType.getStorageType();
929  }
930  if (inputETy != outputETy)
931  return emitOpError("input/output element types are incompatible.");
932 
933  auto maxValAttr = getMaxValAttr();
934  auto minValAttr = getMinValAttr();
935 
936  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
937 
938  if (inputETy.isInteger(dataTypeBitWidth)) {
939  // if input datatype is integer, check that the min_val/max_val attributes
940  // are integer attributes, and that their type is the same as the input's
941  // datatype
942  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
943  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
944  if (!intMaxValAttr || !intMinValAttr ||
945  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
946  (intMaxValAttr.getType() != inputETy))
947  return emitOpError("min/max attributes types are incompatible with "
948  "input/output element types.");
949 
950  const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
951  const APInt minVal = intMinValAttr.getValue();
952  const APInt maxVal = intMaxValAttr.getValue();
953  if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
954  return emitOpError("expected min_val <= max_val, got min_val=")
955  << minValAttr << ", max_val=" << maxValAttr;
956  } else {
957  // otherwise, input datatype is float, check that the min_val/max_val
958  // attributes share the same type and that their type is the same as the
959  // input's datatype
960  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
961  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
962  if (!floatMaxValAttr || !floatMinValAttr ||
963  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
964  (floatMaxValAttr.getType() != inputETy))
965  return emitOpError("min/max attributes types are incompatible with "
966  "input/output element types.");
967 
968  const APFloat minVal = floatMinValAttr.getValue();
969  const APFloat maxVal = floatMaxValAttr.getValue();
970  if (minVal.isNaN() || maxVal.isNaN())
971  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
972  << minValAttr << ", max_val=" << maxValAttr;
973 
974  if (maxVal < minVal)
975  return emitOpError("expected min_val <= max_val, got min_val=")
976  << minValAttr << ", max_val=" << maxValAttr;
977  }
978 
979  return success();
980 }
981 
982 //===----------------------------------------------------------------------===//
983 // TOSA Operator Quantization Builders.
984 //===----------------------------------------------------------------------===//
985 
986 /// This builder is called on all convolution operators except TransposeConv,
987 /// which has specialized output shape semantics. The builder also defines the
988 /// bitwidth of the output given the bit width of the input & weight content.
989 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
990  Type outputType, Value input, Value weight,
991  Value bias, DenseI64ArrayAttr pad,
992  DenseI64ArrayAttr stride,
993  DenseI64ArrayAttr dilation,
994  TypeAttr accType) {
995  auto zps = createZPsAsConst(builder, input, weight);
996  result.addOperands({input, weight, bias, zps.first, zps.second});
997  result.addAttribute("pad", pad);
998  result.addAttribute("stride", stride);
999  result.addAttribute("dilation", dilation);
1000  result.addAttribute("acc_type", accType);
1001  Type finalOutputType = outputType;
1002  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1003  if (quantAttr) {
1004  finalOutputType =
1005  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1006  }
1007  result.addTypes(finalOutputType);
1008 }
1009 
1010 /// Handles tosa.transpose_conv2d which has outpad and output shape
1011 /// attributes.
1012 static void
1013 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1014  Type outputType, Value input, Value weight,
1015  Value bias, DenseI64ArrayAttr outpad,
1016  DenseI64ArrayAttr stride, TypeAttr accType) {
1017  auto zps = createZPsAsConst(builder, input, weight);
1018  result.addOperands({input, weight, bias, zps.first, zps.second});
1019  result.addAttribute("out_pad", outpad);
1020  result.addAttribute("stride", stride);
1021  result.addAttribute("acc_type", accType);
1022  Type finalOutputType = outputType;
1023  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1024  if (quantAttr) {
1025  finalOutputType =
1026  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1027  }
1028  result.addTypes(finalOutputType);
1029 }
1030 
1031 /// The tosa.matmul op is also intended to be generated where a fully_connected
1032 /// op would otherwise be constructed but the weight is not a constant. In this
1033 /// case, the fully_connected op must be expressed using matmul.
1034 /// TODO: Add a link to the legalization document explaining this.
1035 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1036  OperationState &result, Type outputType,
1037  Value a, Value b) {
1038  auto zps = createZPsAsConst(builder, a, b);
1039  result.addOperands({a, b, zps.first, zps.second});
1040 
1041  Type finalOutputType{outputType};
1042  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1043  auto eType = getStorageElementTypeOrSelf(a.getType());
1044  auto inputBits = eType.getIntOrFloatBitWidth();
1045 
1046  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1047  assert(outputShapedType && "Output must be a shaped type");
1048 
1049  IntegerType accElementType;
1050  if (inputBits == 16)
1051  accElementType = builder.getIntegerType(48);
1052  else
1053  accElementType = builder.getI32Type();
1054 
1055  finalOutputType = outputShapedType.clone(accElementType);
1056  }
1057  result.addTypes(finalOutputType);
1058 }
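// For instance, under this builder a quantized tosa.matmul whose operands use
// i8 storage gets an i32 accumulator result type, while i16 storage gets i48;
// non-quantized operands keep the requested outputType unchanged.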
1059 
1060 /// Both the tosa.avg_pool2d and unary ops use the same
1061 /// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
1062 /// it has additional parameters not part of the unary ops.
1063 static void
1064 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1065  Type outputType, Value input,
1066  DenseArrayAttr kernel, DenseArrayAttr stride,
1067  DenseArrayAttr pad, TypeAttr accType) {
1068  const Location loc{result.location};
1069  int64_t inputZp{0};
1070  int64_t outputZp{0};
1071 
1072  if (auto quantAttr =
1073  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1074  inputZp = quantAttr.getInputZp();
1075  outputZp = quantAttr.getOutputZp();
1076  }
1077  const std::optional<Value> inputZpOp =
1078  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1079  if (!inputZpOp) {
1080  (void)emitError(
1081  loc,
1082  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1083  }
1084  const std::optional<Value> outputZpOp =
1085  createZeroPointTensor(builder, loc, outputType, outputZp);
1086  if (!outputZpOp) {
1087  (void)emitError(loc, "Failed to create output zero point tensor for "
1088  "quantized AVG_POOL2D op");
1089  }
1090 
1091  if (inputZpOp && outputZpOp) {
1092  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1093  } else {
1094  // Failed to create one or more zero points above: just add the input as
1095  // an operand; this will trigger an error when building the op because of
1096  // the missing zero points.
1097  result.addOperands({input});
1098  }
1099  result.addAttribute("kernel", kernel);
1100  result.addAttribute("stride", stride);
1101  result.addAttribute("pad", pad);
1102  result.addAttribute("acc_type", accType);
1103  result.types.push_back(outputType);
1104 }
1105 
1106 /// This builder is called on the single-parameter negate operator
1107 /// to construct input and output zero points based on their
1108 /// types.
1109 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1110  OperationState &result, Type outputType,
1111  Value input) {
1112  const Location loc{result.location};
1113  int64_t input1Zp{0};
1114  int64_t outputZp{0};
1115  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1116  if (quantAttr) {
1117  input1Zp = quantAttr.getInputZp();
1118  outputZp = quantAttr.getOutputZp();
1119  }
1120  const std::optional<Value> input1ZpOp =
1121  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1122  if (!input1ZpOp) {
1123  (void)emitError(
1124  loc, "Failed to create input1 zero point for quantized NEGATE op");
1125  }
1126 
1127  const std::optional<Value> outputZpOp =
1128  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1129  if (!outputZpOp) {
1130  (void)emitError(
1131  loc, "Failed to create output zero point for quantized NEGATE op");
1132  }
1133 
1134  if (input1ZpOp && outputZpOp) {
1135  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1136  } else {
1137  // Failed to create one or more zero points above: just add the input as
1138  // an operand. This will trigger an error when building the op because of
1139  // the missing zero points.
1140  result.addOperands({input});
1141  }
1142 
1143  result.types.push_back(outputType);
1144 }
1145 
1146 /// This builder is called on the TOSA pad operator that needs to create its
1147 /// own OptionalAttr quantization_attr parameter to scale the padding values
1148 /// correctly. A missing pad_const is interpreted as zero padding.
1149 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1150  Type outputType, Value input,
1151  Value paddings) {
1152  const Location loc{result.location};
1153  int32_t zp{0};
1154  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1155  if (quantAttr) {
1156  zp = static_cast<int32_t>(quantAttr.getInputZp());
1157  }
1158  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1159  result.addOperands({input, paddings, padConstOp});
1160  result.types.push_back(outputType);
1161 }
1162 
1163 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1164  StringRef name, Type variableType,
1165  Attribute initialValue) {
1166  const Location loc{result.location};
1167  auto nameAttr = builder.getStringAttr(name);
1168 
1169  auto shapedType = dyn_cast<ShapedType>(variableType);
1170  if (!shapedType) {
1171  (void)emitError(loc, "variable type must be a shaped type");
1172  return;
1173  }
1174  if (!shapedType.hasRank()) {
1175  (void)emitError(loc, "variable type must be a ranked type");
1176  return;
1177  }
1178 
1179  auto elementType = shapedType.getElementType();
1180  auto elementTypeAttr = TypeAttr::get(elementType);
1181  ArrayRef<int64_t> shape = shapedType.getShape();
1182  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1183 
1184  result.addAttribute("name", nameAttr);
1185  result.addAttribute("var_shape", varShapeAttr);
1186  result.addAttribute("type", elementTypeAttr);
1187  result.addAttribute("initial_value", initialValue);
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // TOSA Operator Return Type Inference.
1192 //===----------------------------------------------------------------------===//
1193 
1194 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1195  SmallVector<int64_t> &outShape) {
1196  int64_t outRank = 0;
1197  for (int i = 0, e = operands.size(); i != e; ++i) {
1198  auto shape = operands.getShape(i);
1199  if (!shape.hasRank()) {
1200  // TODO(jennik): Update function to have better case handling for
1201  // invalid operands and for ranked tensors.
1202  return failure();
1203  }
1204  outRank = std::max<int64_t>(outRank, shape.getRank());
1205  }
1206 
1207  outShape.resize(outRank, 1);
1208 
1209  for (int i = 0, e = operands.size(); i != e; ++i) {
1210  auto shape = operands.getShape(i);
1211  auto rankDiff = outShape.size() - shape.getRank();
1212 
1213  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1214  auto dim1 = outShape[i + rankDiff];
1215  auto dim2 = shape.getDimSize(i);
1216  auto resolvedDim = dim1;
1217 
1218  if (dim1 == 1) {
1219  resolvedDim = dim2;
1220  } else if (dim2 == 1) {
1221  resolvedDim = dim1;
1222  } else if (dim1 != dim2) {
1223  return failure();
1224  }
1225  outShape[i + rankDiff] = resolvedDim;
1226  }
1227  }
1228 
1229  return success();
1230 }
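// For example, broadcasting operand shapes [2, 1, 3] and [5, 1] resolves to
// an output shape of [2, 5, 3], whereas [2, 3] and [4, 3] fail because 2 and
// 4 are neither equal nor 1.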
1231 
1232 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1233  MLIRContext *context, ::std::optional<Location> location,
1234  ArgMaxOp::Adaptor adaptor,
1235  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1236  ShapeAdaptor inputShape(adaptor.getInput().getType());
1237  IntegerAttr axis = adaptor.getProperties().axis;
1238  int32_t axisVal = axis.getValue().getSExtValue();
1239 
1240  if (!inputShape.hasRank()) {
1241  inferredReturnShapes.push_back(ShapedTypeComponents());
1242  return success();
1243  }
1244 
1245  SmallVector<int64_t> outShape;
1246  outShape.reserve(inputShape.getRank() - 1);
1247  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1248  if (i == axisVal)
1249  continue;
1250  outShape.push_back(inputShape.getDimSize(i));
1251  }
1252 
1253  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1254  return success();
1255 }
1256 
1257 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1258  MLIRContext *context, ::std::optional<Location> location,
1259  RFFT2dOp::Adaptor adaptor,
1260  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1261  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1262 
1263  if (!inputShape.hasRank())
1264  return failure();
1265 
1266  llvm::SmallVector<int64_t> outputShape;
1267  outputShape.resize(3, ShapedType::kDynamic);
1268  outputShape[0] = inputShape.getDimSize(0);
1269  outputShape[1] = inputShape.getDimSize(1);
1270  int64_t inWidth = inputShape.getDimSize(2);
1271 
1272  // Note that we can support this calculation symbolically
1273  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1274  if (inWidth != ShapedType::kDynamic)
1275  outputShape[2] = inWidth / 2 + 1;
1276 
1277  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1278  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1279 
1280  return success();
1281 }
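// For example, an input_real of type tensor<8x16x32xf32> infers two results
// of shape tensor<8x16x17xf32>, since the output width is input_width / 2 + 1.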
1282 
1283 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1284  const llvm::StringRef dimName) {
1285  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1286  if (!isPowerOfTwo)
1287  return op->emitOpError("expected ")
1288  << dimName << " to be a power of two, got " << dimSize;
1289 
1290  return success();
1291 }
1292 
1293 LogicalResult tosa::RFFT2dOp::verify() {
1294  const auto outputTypes = getResultTypes();
1295  if (failed(verifyCompatibleShapes(outputTypes)))
1296  return emitOpError("expected output shapes to match, got ") << outputTypes;
1297 
1298  const auto inputType =
1299  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1300  if (!inputType)
1301  return success();
1302 
1303  const int64_t height = inputType.getDimSize(1);
1304  if (!ShapedType::isDynamic(height) &&
1305  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1306  return failure();
1307 
1308  const int64_t width = inputType.getDimSize(2);
1309  if (!ShapedType::isDynamic(width) &&
1310  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1311  return failure();
1312 
1313  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1314  if (!outputType)
1315  return success();
1316 
1317  // Batch and height input/output dimensions should match
1318  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1319  outputType.getShape().drop_back())))
1320  return emitOpError("expected batch and height dimensions of input/output "
1321  "to match, got input=")
1322  << inputType << " output=" << outputType;
1323 
1324  // Output width dimension expected to be input_width / 2 + 1
1325  const int64_t outputWidth = outputType.getDimSize(2);
1326  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
1327  (outputWidth != (width / 2) + 1))
1328  return emitOpError(
1329  "expected output width to be equal to input_width / 2 + 1, got ")
1330  << outputWidth;
1331 
1332  return success();
1333 }
1334 
1335 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1336  MLIRContext *context, ::std::optional<Location> location,
1337  FFT2dOp::Adaptor adaptor,
1338  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1339  inferredReturnShapes.push_back(
1340  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1341  inferredReturnShapes.push_back(
1342  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1343  return success();
1344 }
1345 
1346 LogicalResult tosa::FFT2dOp::verify() {
1347  const auto inputRealType =
1348  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1349  const auto inputImagType =
1350  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1351  if (!inputRealType || !inputImagType)
1352  return success();
1353 
1354  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1355  return ShapedType::isDynamic(a) ? a : b;
1356  };
1357 
1358  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1359  inputImagType.getDimSize(1));
1360  if (!ShapedType::isDynamic(height) &&
1361  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1362  return failure();
1363 
1364  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1365  inputImagType.getDimSize(2));
1366  if (!ShapedType::isDynamic(width) &&
1367  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1368  return failure();
1369 
1370  return success();
1371 }
1372 
1373 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1374  MLIRContext *context, ::std::optional<Location> location,
1375  ConcatOp::Adaptor adaptor,
1376  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1377  // Infer all dimension sizes by reducing based on inputs.
1378  const Properties &prop = adaptor.getProperties();
1379  int32_t axis = prop.axis.getValue().getSExtValue();
1380  llvm::SmallVector<int64_t> outputShape;
1381  bool hasRankedInput = false;
1382  for (auto operand : adaptor.getOperands()) {
1383  ShapeAdaptor operandShape(operand.getType());
1384  if (!operandShape.hasRank())
1385  continue;
1386 
1387  // Copy the Operand's rank.
1388  if (!hasRankedInput)
1389  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1390 
1391  // Copy shapes until the dim is non-dynamic.
1392  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1393  if (i == axis || operandShape.isDynamicDim(i))
1394  continue;
1395  if (outputShape[i] == ShapedType::kDynamic)
1396  outputShape[i] = operandShape.getDimSize(i);
1397  if (outputShape[i] != operandShape.getDimSize(i))
1398  return emitOptionalError(location,
1399  "Cannot concat tensors with different sizes"
1400  " on the non-axis dimension ",
1401  i);
1402  }
1403 
1404  hasRankedInput = true;
1405  }
1406 
1407  if (adaptor.getInput1().empty())
1408  return failure();
1409 
1410  Type inputType =
1411  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1412  if (!hasRankedInput) {
1413  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1414  return success();
1415  }
1416 
1417  // Determine the dimension size along the concatenation axis.
1418  int64_t concatDimSize = 0;
1419  for (auto operand : adaptor.getOperands()) {
1420  ShapeAdaptor operandShape(operand.getType());
1421 
1422  // We need to know the length of the concatenation axis of all inputs to
1423  // determine the dimension size of the output shape.
1424  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1425  concatDimSize = ShapedType::kDynamic;
1426  break;
1427  }
1428 
1429  concatDimSize += operandShape.getDimSize(axis);
1430  }
1431 
1432  outputShape[axis] = concatDimSize;
1433 
1434  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1435  return success();
1436 }
1437 
1438 LogicalResult tosa::ConcatOp::verify() {
1439  // check that each input has same element type as output
1440  auto outType = getOutput().getType();
1441  const Operation::operand_range inputList = getInput1();
1442 
1443  // Check there is at least one input
1444  if (inputList.empty())
1445  return emitOpError("expect at least one input");
1446 
1447  if (!llvm::all_of(inputList, [&](auto input) {
1448  return succeeded(verifySameElementTypes(
1449  *this, /* inType = */ input.getType(), outType));
1450  })) {
1451  return failure();
1452  }
1453 
1454  const int32_t axis = getAxis();
1455  ShapeAdaptor firstRankedInputShape = nullptr;
1456  for (const auto &input : inputList) {
1457  const Type inputType = input.getType();
1458  ShapeAdaptor currShape(inputType);
1459  if (currShape.hasRank()) {
1460  firstRankedInputShape = currShape;
1461  // Check axis is in expected range
1462  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1463  return emitOpError("expect axis to be within range 0 < axis < "
1464  "rank(input1[firstRankedTensorIdx]), got ")
1465  << axis;
1466  break;
1467  }
1468  }
1469 
1470  const auto allOperandsHasRank = [](const Value input) {
1471  return ShapeAdaptor(input.getType()).hasRank();
1472  };
1473  if (llvm::all_of(inputList, allOperandsHasRank)) {
1474  const int64_t firstInputRank = firstRankedInputShape.getRank();
1475 
1476  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1477  const ShapeAdaptor inputShape(input.getType());
1478  const int64_t inputRank = inputShape.getRank();
1479  const size_t operandNum = index + 1;
1480 
1481  // Check that each operand has the same rank
1482  if (inputRank != firstInputRank)
1483  return emitOpError(
1484  "expect all operands to have the same rank, but got ")
1485  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1486  << operandNum;
1487 
1488  // Check non-axis dims match
1489  for (int i = 0; i < inputRank; i++) {
1490  const int64_t inputDim = inputShape.getDimSize(i);
1491  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1492  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1493  inputShape.isDynamicDim(i))
1494  continue;
1495  if (inputDim != firstInputDim)
1496  return emitOpError("expect all operand shapes to have the same sizes "
1497  "on non-axis dimensions, but got ")
1498  << inputDim << " vs " << firstInputDim << " at index " << i
1499  << " on operands 0 and " << operandNum;
1500  }
1501  }
1502 
1503  // ERROR_IF(axis_sum != shape[axis]);
1504  int64_t axisSum = 0;
1505  for (const auto &input : inputList) {
1506  const ShapeAdaptor inputShape(input.getType());
1507  if (inputShape.isDynamicDim(axis)) {
1508  // make axisSum negative to indicate the axis size is unknown
1509  axisSum = -1;
1510  break;
1511  }
1512  axisSum += inputShape.getDimSize(axis);
1513  }
1514  const ShapeAdaptor outputShape(outType);
1515  if (axisSum >= 0 && outputShape.hasRank() &&
1516  !outputShape.isDynamicDim(axis) &&
1517  axisSum != outputShape.getDimSize(axis))
1518  return emitOpError("requires sum of axis dimensions of input1 "
1519  "equal to output axis dimension, got ")
1520  << axisSum << " and " << outputShape.getDimSize(axis);
1521  }
1522 
1523  return success();
1524 }
1525 
1526 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1527  MLIRContext *context, ::std::optional<Location> location,
1528  ValueShapeRange operands, DictionaryAttr attributes,
1529  OpaqueProperties properties, RegionRange regions,
1530  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1531  auto elementType = IntegerType::get(context, /*width=*/1);
1532 
1533  llvm::SmallVector<int64_t> outShape;
1534  if (resolveBroadcastShape(operands, outShape).failed()) {
1535  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1536  return success();
1537  }
1538 
1539  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1540  return success();
1541 }
1542 
1543 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1544  if (l.size() != r.size() || l.size() != 1)
1545  return false;
1546  return succeeded(verifyCompatibleShape(l[0], r[0]));
1547 }
1548 
1549 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1550  MLIRContext *context, ::std::optional<Location> location,
1551  MatMulOp::Adaptor adaptor,
1552  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1553  ShapeAdaptor lhsShape(adaptor.getA().getType());
1554  ShapeAdaptor rhsShape(adaptor.getB().getType());
1555 
1556  // Start with an all-dynamic output shape.
1557  SmallVector<int64_t> outShape;
1558  outShape.resize(3, ShapedType::kDynamic);
1559 
1560  if (lhsShape.hasRank()) {
1561  outShape[0] = lhsShape.getDimSize(0);
1562  outShape[1] = lhsShape.getDimSize(1);
1563  }
1564 
1565  if (rhsShape.hasRank()) {
1566  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1567  : outShape[0];
1568  outShape[2] = rhsShape.getDimSize(2);
1569  }
1570 
1571  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1572  return success();
1573 }
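
// Illustrative sketch (editor's addition, not in the upstream file): for
// a = tensor<?x4x5xf32> and b = tensor<2x5x6xf32>, the lhs contributes
// [?, 4, ?], and the rhs then fills the batch dim and the last dim, giving
// an inferred output of tensor<2x4x6xf32>.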
1574 
1575 LogicalResult MatMulOp::verify() {
1576  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1577  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1578 
1579  // Must be shaped tensor types
1580  if (!aType)
1581  return emitOpError("expect a shaped tensor for input a, got ")
1582  << getA().getType();
1583 
1584  if (!bType)
1585  return emitOpError("expect a shaped tensor for input b, got ")
1586  << getB().getType();
1587 
1588  auto aElementType = aType.getElementType();
1589  auto bElementType = bType.getElementType();
1590 
1591  auto aQuantizedEType =
1592  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1593  auto bQuantizedEType =
1594  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1595 
1596  if (aQuantizedEType || bQuantizedEType) {
1597  if (!aQuantizedEType || !bQuantizedEType) {
1598  return emitOpError("expect operands to be both quantized or both not "
1599  "quantized, got ")
1600  << aElementType << " and " << bElementType;
1601  }
1602  // both a and b have quantized element types
1603  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1604  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1605  if (aQuantWidth != bQuantWidth) {
1606  return emitOpError("expect quantized operands to have same widths, got ")
1607  << aQuantWidth << " and " << bQuantWidth;
1608  }
1609  } else {
1610  // non-quantized element types
1611  if (aElementType != bElementType) {
1612  return emitOpError("expect same element type for inputs a and b, got ")
1613  << aElementType << " and " << bElementType;
1614  }
1615  }
1616 
1617  // check a_zp and b_zp
1618  auto aEType = getStorageElementTypeOrSelf(aType);
1619  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1620  if (aEType != aZpEType) {
1621  return emitOpError("expect input a and a_zp have the same "
1622  "element type, got ")
1623  << aEType << " and " << aZpEType;
1624  }
1625 
1626  auto bEType = getStorageElementTypeOrSelf(bType);
1627  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1628  if (bEType != bZpEType) {
1629  return emitOpError("expect input b and b_zp have the same "
1630  "element type, got ")
1631  << bEType << " and " << bZpEType;
1632  }
1633 
1634  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1635  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1636  return failure();
1637 
1638  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1639  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1640  return failure();
1641 
1642  return success();
1643 }
1644 
1645 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1646  MLIRContext *context, ::std::optional<Location> location,
1647  PadOp::Adaptor adaptor,
1648  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1649  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1650  auto paddingRank =
1651  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1652  SmallVector<int64_t> outputShape;
1653 
1654  // If the input rank is unknown, we can infer the output rank using the
1655  // padding shape's rank divided by 2.
1656  if (!inputShape.hasRank()) {
1657  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1658  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1659  return success();
1660  }
1661 
1662  SmallVector<int64_t> paddingValues;
1663  // If the padding values are not constant, all dimensions are inferred as dynamic.
1664  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1665  paddingValues)) {
1666  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1667  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1668  return success();
1669  }
1670 
1671  outputShape.reserve(inputShape.getRank());
1672  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1673  if (inputShape.isDynamicDim(i)) {
1674  outputShape.push_back(ShapedType::kDynamic);
1675  continue;
1676  }
1677  auto padFront = paddingValues[i * 2];
1678  auto padBack = paddingValues[i * 2 + 1];
1679  if (padFront < 0 || padBack < 0) {
1680  // if either padding value for dim i is negative (dynamic), the output dim is unknown
1681  outputShape.push_back(ShapedType::kDynamic);
1682  continue;
1683  }
1684 
1685  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1686  }
1687 
1688  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1689  return success();
1690 }
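
// Illustrative sketch (editor's addition, not in the upstream file): for an
// input of tensor<3x4xf32> with constant padding values [1, 1, 2, 2]
// (front/back per dimension), the inferred output is
// tensor<(3+1+1)x(4+2+2)xf32> = tensor<5x8xf32>.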
1691 
1692 LogicalResult tosa::PadOp::verify() {
1693  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1694  /* outType = */ getOutput().getType())
1695  .failed()) {
1696  return failure();
1697  }
1698 
1699  if (auto padConst = getPadConst()) {
1700  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1701  /* outType = */ getOutput().getType())
1702  .failed()) {
1703  return failure();
1704  }
1705  }
1706 
1707  RankedTensorType inputType =
1708  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1709  RankedTensorType outputType =
1710  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1711  if (!inputType || !outputType)
1712  return success();
1713 
1714  auto inputRank = inputType.getRank();
1715  auto outputRank = outputType.getRank();
1716  if (inputRank != outputRank)
1717  return emitOpError() << "expect same input and output tensor rank, but got "
1718  << "inputRank: " << inputRank
1719  << ", outputRank: " << outputRank;
1720 
1721  DenseIntElementsAttr paddingAttr;
1722  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1723  return failure();
1724  }
1725 
1726  auto paddingValues = paddingAttr.getValues<APInt>();
1727  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1728  return emitOpError() << "padding tensor must have " << inputRank
1729  << " * 2 = " << inputRank * 2 << " elements, but got "
1730  << paddingValues.size();
1731 
1732  auto inputShape = inputType.getShape();
1733  auto outputShape = outputType.getShape();
1734 
1735  for (int64_t i = 0; i < inputRank; ++i) {
1736  int64_t padStart = paddingValues[i * 2].getSExtValue();
1737  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1738 
1739  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1740  return emitOpError()
1741  << "invalid padding values at dimension " << i
1742  << ": values must be non-negative or -1 for dynamic padding, got ["
1743  << padStart << ", " << padEnd << "]";
1744  }
1745 
1746  // Skip shape verification for dynamic input/output
1747  if (inputShape[i] == ShapedType::kDynamic ||
1748  outputShape[i] == ShapedType::kDynamic)
1749  continue;
1750 
1751  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1752  return emitOpError() << "mismatch in output shape at dimension " << i
1753  << ": expected " << inputShape[i] << " + "
1754  << padStart << " + " << padEnd << " = "
1755  << (inputShape[i] + padStart + padEnd)
1756  << ", but got " << outputShape[i];
1757  }
1758  }
1759 
1760  return success();
1761 }
1762 
1763 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1764  MLIRContext *context, ::std::optional<Location> location,
1765  SliceOp::Adaptor adaptor,
1766  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1767 
1768  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1769  SmallVector<int64_t> start;
1770  SmallVector<int64_t> size;
1771 
1772  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1773  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1774  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1775  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1776  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1777  return success();
1778  }
1779 
1780  // if size[i] is -1, all remaining elements in dimension i are included
1781  // in the slice, similar to TF.
1782  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1783  // initialize outputShape to all unknown
1784  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1785  if (inputShape.hasRank()) {
1786  for (size_t i = 0; i < size.size(); i++) {
1787  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1788  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1789  start[i] < inputShape.getDimSize(i))) {
1790  // size[i] is not 0 and not < -1, and start[i] is in valid range
1791  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1792  // input shape has unknown dim[i] - only valid if size[i] > 0
1793  if (size[i] > 0) {
1794  outputShape[i] = size[i];
1795  }
1796  } else {
1797  // input shape has known dim[i]
1798  if (size[i] == -1) {
1799  outputShape[i] = inputShape.getDimSize(i) - start[i];
1800  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1801  // start[i] + size[i] is within bound of input shape's dim[i]
1802  outputShape[i] = size[i];
1803  }
1804  }
1805  }
1806  }
1807  } else {
1808  outputShape = convertToMlirShape(size);
1809  }
1810  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1811  return success();
1812 }
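
// Illustrative sketch (editor's addition, not in the upstream file): for an
// input of tensor<6x8xf32> with start = [1, 2] and size = [4, -1], dim 0
// becomes 4 (since 1 + 4 <= 6) and dim 1 becomes 8 - 2 = 6 (size -1 takes
// the remainder), giving an inferred shape of 4x6.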
1813 
1814 LogicalResult tosa::SliceOp::verify() {
1815  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1816  /* outType = */ getOutput().getType())
1817  .failed())
1818  return failure();
1819 
1820  const ShapeAdaptor inputShape(getInput1().getType());
1821  if (inputShape.hasRank()) {
1822  const auto inputRank = inputShape.getRank();
1823  const ShapeAdaptor outputShape(getOutput().getType());
1824  if (outputShape.hasRank() && inputRank != outputShape.getRank())
1825  return emitOpError(
1826  "expect input1 and output to have the same ranks, got ")
1827  << inputRank << " and " << outputShape.getRank();
1828 
1829  const auto startShapeRank =
1830  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1831  if (inputRank != startShapeRank)
1832  return emitOpError("length of start is not equal to rank of input shape");
1833 
1834  const auto sizeShapeRank =
1835  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1836  if (inputRank != sizeShapeRank)
1837  return emitOpError("length of size is not equal to rank of input shape");
1838  }
1839 
1840  return success();
1841 }
1842 
1843 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1844  MLIRContext *context, ::std::optional<Location> location,
1845  ValueShapeRange operands, DictionaryAttr attributes,
1846  OpaqueProperties properties, RegionRange regions,
1847  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1848  // The mul op's output shape depends only on input1 and input2, not on shift.
1849  ValueShapeRange twoInputs = operands.drop_back();
1850  llvm::SmallVector<int64_t> outShape;
1851  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1852  inferredReturnShapes.push_back(ShapedTypeComponents());
1853  } else {
1854  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1855  }
1856  return success();
1857 }
1858 
1859 LogicalResult tosa::MulOp::verify() {
1860  const Value output = getOutput();
1861  auto resElemType = getElementTypeOrSelf(output);
1862 
1863  // Verify if the element type among operands and result match tosa
1864  // specification.
1865  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1866  IntegerType lhsIntType =
1867  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1868  IntegerType rhsIntType =
1869  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1870  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1871  return emitOpError("requires the same element type for all operands");
1872 
1873  // Though the spec requires the element type of the result to be i32, the
1874  // dialect is more relaxed here to make it easier to interoperate with
1875  // other dialects.
1876  if (lhsIntType.getWidth() > resIntType.getWidth())
1877  return emitOpError("invalid data type size for operands or result");
1878 
1879  } else {
1880  // For other supported types, the spec requires the same element type for
1881  // all operands (excluding the `shift` operand) and results.
1882  for (int i = 0; i < 2; ++i) {
1883  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1884  return emitOpError(
1885  "requires the same element type for all operands and results");
1886  }
1887 
1888  // verify shift has value 0 for non-integer types
1889  ElementsAttr shift_elem;
1890  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1891  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1892  if (shift != 0) {
1893  return emitOpError() << "require shift to be 0 for float type";
1894  }
1895  }
1896  }
1897 
1898  // Verify the op has the same rank for all main operands (excluding extra
1899  // operands such as the shift of mul op, which is the only difference from
1900  // the built-in `SameOperandsAndResultRank` trait) and result types, if known.
1901  TypeRange operandTypes = getOperandTypes();
1902  ShapedType aType = cast<ShapedType>(operandTypes[0]);
1903  ShapedType bType = cast<ShapedType>(operandTypes[1]);
1904 
1905  const bool aHasRank = aType.hasRank();
1906  const bool bHasRank = bType.hasRank();
1907  if (aHasRank && bHasRank) {
1908  const int64_t aRank = aType.getRank();
1909  const int64_t bRank = bType.getRank();
1910  if (aRank != bRank)
1911  return emitOpError("a and b operands don't have matching ranks, got ")
1912  << aRank << " and " << bRank;
1913 
1914  // check for broadcast compatible shapes
1915  SmallVector<int64_t> resultShape;
1916  if (!mlir::OpTrait::util::getBroadcastedShape(
1917  aType.getShape(), bType.getShape(), resultShape))
1918  return emitOpError("a and b operands don't have broadcast-compatible "
1919  "shapes, got ")
1920  << aType << " and " << bType;
1921  }
1922 
1923  ShapedType resultType = cast<ShapedType>(output.getType());
1924  if (!resultType.hasRank())
1925  return success();
1926 
1927  const int64_t resultRank = resultType.getRank();
1928  if (aHasRank && resultRank != aType.getRank())
1929  return emitOpError("result type has different rank than a, got ")
1930  << resultRank << " vs " << aType.getRank();
1931  if (bHasRank && resultRank != bType.getRank())
1932  return emitOpError("result type has different rank than b, got ")
1933  << resultRank << " vs " << bType.getRank();
1934 
1935  return success();
1936 }
1937 
1938 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1939  MLIRContext *context, ::std::optional<Location> location,
1940  TableOp::Adaptor adaptor,
1941  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1942  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1943 
1944  if (!inputShape.hasRank()) {
1945  inferredReturnShapes.push_back(ShapedTypeComponents());
1946  return success();
1947  }
1948 
1949  inferredReturnShapes.resize(1);
1950  inputShape.getDims(inferredReturnShapes[0]);
1951  return success();
1952 }
1953 
1954 LogicalResult tosa::TableOp::verify() {
1955  TensorType inputType = getInput1().getType();
1956  TensorType outputType = getOutput().getType();
1957 
1958  if (inputType.hasRank() && outputType.hasRank() &&
1959  inputType.getRank() != outputType.getRank())
1960  return emitOpError()
1961  << "expected input tensor rank to equal result tensor rank";
1962 
1963  auto inputDims = inputType.getShape();
1964  auto outputDims = outputType.getShape();
1965  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1966  int64_t dim = it.index();
1967  auto [inputDim, outputDim] = it.value();
1968  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1969  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1970  << " doesn't match dim(input, " << dim
1971  << ") = " << inputDim;
1972  }
1973  }
1974  return success();
1975 }
1976 
1977 LogicalResult
1978 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1979  // Multiples must be constants.
1980  DenseIntElementsAttr multiplesAttr;
1981  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1982  return failure();
1983  multiples = llvm::to_vector(
1984  llvm::map_range(multiplesAttr.getValues<APInt>(),
1985  [](const APInt &val) { return val.getSExtValue(); }));
1986  return success();
1987 }
1988 
1989 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1990  MLIRContext *context, ::std::optional<Location> location,
1991  TileOp::Adaptor adaptor,
1992  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1993  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1994  SmallVector<int64_t> multiples;
1995  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1996  multiples)) {
1997  auto rank =
1998  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1999  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2000  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2001  return success();
2002  } else {
2003  multiples = convertToMlirShape(multiples);
2004  }
2005 
2006  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2007  SmallVector<int64_t> outputShape;
2008  if (!inputShape.hasRank()) {
2009  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2010  inferredReturnShapes.push_back(
2011  ShapedTypeComponents(outputShape, inputType));
2012  return success();
2013  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2014  return failure();
2015 
2016  // Any non-dynamic dimension can be multiplied to a known size.
2017  outputShape.reserve(multiples.size());
2018  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2019  if (multiples[i] == ShapedType::kDynamic) {
2020  outputShape.push_back(ShapedType::kDynamic);
2021  } else {
2022  int64_t dim = inputShape.getDimSize(i);
2023  if (dim != ShapedType::kDynamic)
2024  dim *= multiples[i];
2025  outputShape.push_back(dim);
2026  }
2027  }
2028 
2029  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2030  return success();
2031 }
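
// Illustrative sketch (editor's addition, not in the upstream file): tiling
// tensor<2x?x3xf32> with constant multiples [3, 1, 2] yields [2*3, ?, 3*2],
// i.e. an inferred output of tensor<6x?x6xf32>; the dynamic dim stays
// dynamic regardless of its multiple.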
2032 
2033 LogicalResult tosa::TileOp::verify() {
2034  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2035  /* outType = */ getOutput().getType())
2036  .failed()) {
2037  return failure();
2038  }
2039  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2040  ShapedType outputType = llvm::cast<ShapedType>(getType());
2041 
2042  shapeType multiplesType =
2043  llvm::cast<tosa::shapeType>(getMultiples().getType());
2044 
2045  auto multiplesRank = multiplesType.getRank();
2046 
2047  if (inputType.hasRank()) {
2048  if (inputType.getRank() != multiplesRank)
2049  return emitOpError("expect 'multiples' to have rank ")
2050  << inputType.getRank() << " but got " << multiplesRank << ".";
2051  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2052  return emitOpError("expect same input and output tensor rank.");
2053  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2054  return emitOpError("expect 'multiples' array to have length ")
2055  << outputType.getRank() << " but got " << multiplesRank << ".";
2056 
2057  SmallVector<int64_t> multiples;
2058  if (getConstantMultiples(multiples).succeeded() &&
2059  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2060  return emitOpError(
2061  "expect element of 'multiples' to be positive integer or -1.");
2062 
2063  return success();
2064 }
2065 
2066 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2067  if (l.size() != r.size() || l.size() != 1)
2068  return false;
2069  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2070 }
2071 
2072 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2073  MLIRContext *context, ::std::optional<Location> location,
2074  ReshapeOp::Adaptor adaptor,
2075  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2076  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2077  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2078  llvm::SmallVector<int64_t> newShapeValue;
2079  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2080  newShapeValue)) {
2081  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2082  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2083  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2084  return success();
2085  } else {
2086  newShapeValue = convertToMlirShape(newShapeValue);
2087  }
2088 
2089  // We cannot infer from the total number of elements so we must take the
2090  // shape attribute as exact.
2091  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2092  inferredReturnShapes.push_back(
2093  ShapedTypeComponents(newShapeValue, inputType));
2094  return success();
2095  }
2096 
2097  // Determine the number of elements covered by the slice of all static
2098  // dimensions. This allows us to infer the length of the remaining dynamic
2099  // dimension.
2100  int64_t numElements = inputShape.getNumElements();
2101  int64_t staticMul = 1;
2102  for (auto val : newShapeValue) {
2103  if (!ShapedType::isDynamic(val)) {
2104  staticMul *= val;
2105  }
2106  }
2107 
2108  // Determine the length of the dynamic dimension.
2109  for (auto &val : newShapeValue) {
2110  if (ShapedType::isDynamic(val))
2111  val = numElements / staticMul;
2112  }
2113 
2114  inferredReturnShapes.push_back(
2115  ShapedTypeComponents(newShapeValue, inputType));
2116  return success();
2117 }
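
// Illustrative sketch (editor's addition, not in the upstream file):
// reshaping a static tensor<2x3x4xf32> (24 elements) with shape [4, -1]
// resolves the -1 dimension to 24 / 4 = 6, so the inferred output is
// tensor<4x6xf32>.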
2118 
2119 llvm::LogicalResult tosa::ReshapeOp::verify() {
2120  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2121  /* outType = */ getOutput().getType())
2122  .failed()) {
2123  return failure();
2124  }
2125  TensorType inputType = getInput1().getType();
2126 
2127  SmallVector<int64_t> shapeValues;
2128  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2129  // skip following checks if shape is not constant
2130  return mlir::success();
2131  }
2132 
2133  int missingDims = llvm::count(shapeValues, -1);
2134  if (missingDims > 1)
2135  return emitOpError() << "expected at most one target dimension to be -1";
2136 
2137  const auto outputType = dyn_cast<RankedTensorType>(getType());
2138  if (!outputType)
2139  return success();
2140 
2141  if ((int64_t)shapeValues.size() != outputType.getRank())
2142  return emitOpError() << "new shape does not match result rank";
2143 
2144  for (auto [newShapeDim, outputShapeDim] :
2145  zip(shapeValues, outputType.getShape())) {
2146  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2147  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2148  return emitOpError() << "new shape is inconsistent with result shape";
2149 
2150  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2151  return emitOpError() << "new shape has invalid tensor dimension size "
2152  << newShapeDim;
2153  }
2154 
2155  if (inputType.hasStaticShape()) {
2156  int64_t inputElementsNum = inputType.getNumElements();
2157  if (outputType.hasStaticShape()) {
2158  int64_t outputElementsNum = outputType.getNumElements();
2159  if (inputElementsNum != outputElementsNum) {
2160  return emitOpError() << "cannot reshape " << inputElementsNum
2161  << " elements into " << outputElementsNum;
2162  }
2163  }
2164 
2165  int64_t newShapeElementsNum = std::accumulate(
2166  shapeValues.begin(), shapeValues.end(), 1LL,
2167  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2168  bool isStaticNewShape =
2169  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2170  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2171  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2172  return emitOpError() << "cannot reshape " << inputElementsNum
2173  << " elements into " << newShapeElementsNum;
2174  }
2175  }
2176 
2177  return mlir::success();
2178 }
2179 
2180 // Return failure if val is not a constant.
2181 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
2182 // otherwise set zp to val's constant value.
2183 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2184  ElementsAttr zpAttr;
2185  if (!matchPattern(val, m_Constant(&zpAttr))) {
2186  return failure();
2187  }
2188 
2189  Type zpElemType = zpAttr.getElementType();
2190 
2191  if (llvm::isa<FloatType>(zpElemType)) {
2192  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2193  return 0;
2194  }
2195  // return non-zero value to trigger error check
2196  return -1;
2197  }
2198 
2199  if (llvm::isa<IntegerType>(zpElemType)) {
2200  if (signExtend)
2201  return zpAttr.getValues<APInt>()[0].getSExtValue();
2202  else
2203  return zpAttr.getValues<APInt>()[0].getZExtValue();
2204  }
2205 
2206  // return non-zero value to trigger error check
2207  return -1;
2208 }
2209 
2210 template <typename T>
2211 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2212  const std::string &operand) {
2213  Type zpElemType = getElementTypeOrSelf(val);
2214 
2215  if (!zpElemType.isInteger(8) && zp != 0) {
2216  // convert operand to lower case for error message
2217  std::string lower = operand;
2218  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2219  return op.emitOpError()
2220  << lower << " zero point must be zero for non-int8 integer types";
2221  }
2222 
2223  return success();
2224 }
2225 
2226 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2227  const int64_t &zp,
2228  const std::string &operand) {
2229  bool isInputZp = (operand == "Input");
2230 
2231  bool tensorUnsigned =
2232  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2233  StringRef tensorName = isInputZp ? "input" : "output";
2234 
2235  Type zpElemType = getElementTypeOrSelf(zpVal);
2236 
2237  if (zp != 0) {
2238  if (!zpElemType.isInteger(8) &&
2239  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2240  return op.emitOpError()
2241  << "expect " << tensorName << "_zp of 0, got " << zp;
2242  }
2243  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2244  return op.emitOpError() << "expect " << tensorName
2245  << "_zp of 0 or 32768 for unsigned int16 "
2246  << tensorName << ", got " << zp;
2247  }
2248  }
2249 
2250  return success();
2251 }
2252 
2253 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2254  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2255  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2256  } \
2257  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2258  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2259  }
2260 
2261 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2262 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2263 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2264 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2265 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2266 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2267 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2268 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2269 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2270 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2271 ZERO_POINT_HELPER(MatMulOp, A, true)
2272 ZERO_POINT_HELPER(MatMulOp, B, true)
2273 ZERO_POINT_HELPER(NegateOp, Input1, true)
2274 ZERO_POINT_HELPER(NegateOp, Output, true)
2275 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2276 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2277 #undef ZERO_POINT_HELPER
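
// For reference (editor's sketch of the macro expansion, not in the upstream
// file): ZERO_POINT_HELPER(MatMulOp, A, true) expands to roughly
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }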
2278 
2279 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2280  MLIRContext *context, ::std::optional<Location> location,
2281  TransposeOp::Adaptor adaptor,
2282  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2283  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2284 
2285  // If the input rank is unknown, the output rank is unknown as
2286  // well.
2287  if (!inputShape.hasRank()) {
2288  inferredReturnShapes.push_back(ShapedTypeComponents());
2289  return success();
2290  }
2291 
2292  const auto inputRank = inputShape.getRank();
2293 
2294  // This would imply the number of permutations does not match the rank of
2295  // the input which is illegal.
2296  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2297  return failure();
2298  }
2299 
2300  SmallVector<int64_t> outputShape;
2301  // Rank-0 means no permutations matter.
2302  if (inputRank == 0) {
2303  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2304  return success();
2305  }
2306 
2307  // Check whether the input dimensions are all the same.
2308  bool allTheSame = true;
2309  for (int i = 1, s = inputRank; i < s; i++) {
2310  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2311  allTheSame = false;
2312  break;
2313  }
2314  }
2315 
2316  // If all of the input dimensions are the same we don't care about the
2317  // permutation.
2318  if (allTheSame) {
2319  outputShape.resize(inputRank, inputShape.getDimSize(0));
2320  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2321  return success();
2322  }
2323 
2324  outputShape.resize(inputRank, ShapedType::kDynamic);
2325 
2326  // Constant permutation values must be within the input rank.
2327  if (llvm::any_of(adaptor.getPerms(),
2328  [inputRank](const auto i) { return i >= inputRank; }))
2329  return failure();
2330 
2331  outputShape.reserve(inputRank);
2332  for (int i = 0, s = inputRank; i < s; i++) {
2333  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2334  }
2335 
2336  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2337  return success();
2338 }
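
// Illustrative sketch (editor's addition, not in the upstream file):
// transposing tensor<1x2x3xf32> with perms = [2, 0, 1] picks
// outputShape[i] = inputShape[perms[i]], giving an inferred output of
// tensor<3x1x2xf32>.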
2339 
2340 LogicalResult tosa::TransposeOp::verify() {
2341  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2342  /* outType = */ getOutput().getType())
2343  .failed()) {
2344  return failure();
2345  }
2346 
2347  const ShapeAdaptor inputShape(getInput1().getType());
2348  const ShapeAdaptor outputShape(getOutput().getType());
2349 
2350  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2351 
2352  if (inputShape.hasRank() &&
2353  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2354  return emitOpError() << "expected perms attribute to have size "
2355  << inputShape.getRank()
2356  << " (input rank) but got size "
2357  << constantPerms.size();
2358 
2359  if (inputShape.hasRank() && outputShape.hasRank() &&
2360  inputShape.getRank() != outputShape.getRank())
2361  return emitOpError()
2362  << "expected input tensor rank to equal result tensor rank";
2363 
2364  if (outputShape.hasRank() &&
2365  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2366  return emitOpError() << "expected perms attribute to have size "
2367  << outputShape.getRank()
2368  << " (output rank) but got size "
2369  << constantPerms.size();
2370 
2371  if (!llvm::all_of(constantPerms,
2372  [&constantPerms](int32_t s) {
2373  return s >= 0 &&
2374  static_cast<size_t>(s) < constantPerms.size();
2375  }) ||
2376  !isPermutationVector(llvm::to_vector(llvm::map_range(
2377  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2378  return emitOpError() << "expected valid permutation indices";
2379 
2380  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2381  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2382  inputShape.getNumElements() != outputShape.getNumElements())
2383  return emitOpError() << "expected input1 and output to have same numbers "
2384  "of elements, got "
2385  << inputShape.getNumElements() << " and "
2386  << outputShape.getNumElements();
2387 
2388  // Verify that the types of the input and output tensors are properly
2389  // permuted.
2390  if (inputShape.hasRank() && outputShape.hasRank()) {
2391  for (auto i = 0; i < outputShape.getRank(); i++) {
2392  if (inputShape.isDynamicDim(constantPerms[i]) ||
2393  outputShape.isDynamicDim(i))
2394  continue;
2395 
2396  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2397  return emitOpError()
2398  << "expected output tensor dim " << i << " to match "
2399  << "input dim " << constantPerms[i] << " with value of "
2400  << inputShape.getDimSize(constantPerms[i]);
2401  }
2402  }
2403 
2404  return success();
2405 }
2406 
2407 LogicalResult TransposeOp::reifyResultShapes(
2408  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2409 
2410  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2411 
2412  Value input = getInput1();
2413  auto inputType = cast<TensorType>(input.getType());
2414 
2415  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2416  for (auto dim : transposePerms) {
2417  int32_t dimInInput = transposePerms[dim];
2418  if (inputType.isDynamicDim(dimInInput))
2419  returnedDims[dim] =
2420  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
2421  .getResult();
2422  else
2423  returnedDims[dim] =
2424  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2425  }
2426 
2427  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2428  return success();
2429 }
2430 
2431 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2432  MLIRContext *context, ::std::optional<Location> location,
2433  GatherOp::Adaptor adaptor,
2434  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2435  llvm::SmallVector<int64_t> outputShape;
2436  outputShape.resize(3, ShapedType::kDynamic);
2437 
2438  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2439  if (valuesShape.hasRank()) {
2440  outputShape[0] = valuesShape.getDimSize(0);
2441  outputShape[2] = valuesShape.getDimSize(2);
2442  }
2443 
2444  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2445  if (indicesShape.hasRank()) {
2446  if (outputShape[0] == ShapedType::kDynamic)
2447  outputShape[0] = indicesShape.getDimSize(0);
2448  if (outputShape[1] == ShapedType::kDynamic)
2449  outputShape[1] = indicesShape.getDimSize(1);
2450  }
2451 
2452  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2453  return success();
2454 }
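
// Illustrative sketch (editor's addition, not in the upstream file): for
// values = tensor<2x?x5xf32> and indices = tensor<2x7xi32>, N and C come
// from values ([2, ?, 5]) and W comes from indices, giving an inferred
// output of tensor<2x7x5xf32>.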
2455 
2456 LogicalResult tosa::GatherOp::verify() {
2457  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2458  /* outType = */ getOutput().getType())
2459  .failed()) {
2460  return failure();
2461  }
2462 
2463  const ShapeAdaptor valuesShape(getValues().getType());
2464  const ShapeAdaptor indicesShape(getIndices().getType());
2465  const ShapeAdaptor outputShape(getOutput().getType());
2466 
2467  int64_t N = ShapedType::kDynamic;
2468  int64_t W = ShapedType::kDynamic;
2469  int64_t C = ShapedType::kDynamic;
2470 
2471  if (valuesShape.hasRank()) {
2472  N = valuesShape.getDimSize(0);
2473  C = valuesShape.getDimSize(2);
2474  }
2475  if (indicesShape.hasRank()) {
2476  const int64_t indicesN = indicesShape.getDimSize(0);
2477  W = indicesShape.getDimSize(1);
2478  if (N == ShapedType::kDynamic)
2479  N = indicesN;
2480  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2481  return emitOpError() << "requires indices dimension 0 to have size " << N
2482  << ", got " << indicesN;
2483  }
2484  if (outputShape.hasRank()) {
2485  const int64_t outputN = outputShape.getDimSize(0);
2486  const int64_t outputW = outputShape.getDimSize(1);
2487  const int64_t outputC = outputShape.getDimSize(2);
2488  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2489  N != outputN)
2490  return emitOpError() << "requires output dimension 0 to have size " << N
2491  << ", got " << outputN;
2492 
2493  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2494  W != outputW)
2495  return emitOpError() << "requires output dimension 1 to have size " << W
2496  << ", got " << outputW;
2497  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2498  C != outputC)
2499  return emitOpError() << "requires output dimension 2 to have size " << C
2500  << ", got " << outputC;
2501  }
2502  return success();
2503 }
2504 
2505 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2506  MLIRContext *context, ::std::optional<Location> location,
2507  ResizeOp::Adaptor adaptor,
2508  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2509  llvm::SmallVector<int64_t, 4> outputShape;
2510  outputShape.resize(4, ShapedType::kDynamic);
2511 
2512  ShapeAdaptor inputShape(adaptor.getInput().getType());
2513  if (!inputShape.hasRank())
2514  return failure();
2515 
2516  outputShape[0] = inputShape.getDimSize(0);
2517  outputShape[3] = inputShape.getDimSize(3);
2518  int64_t inputHeight = inputShape.getDimSize(1);
2519  int64_t inputWidth = inputShape.getDimSize(2);
2520 
2521  if ((inputHeight == ShapedType::kDynamic) ||
2522  (inputWidth == ShapedType::kDynamic))
2523  return failure();
2524 
2525  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2526  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2527  scaleInt) ||
2528  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2529  offsetInt) ||
2530  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2531  borderInt)) {
2532  return failure();
2533  }
2534 
2535  // Compute the output shape based on attributes: scale, offset, and border.
2536  const int64_t outputHeight =
2537  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2538  scaleInt[1]) +
2539  1;
2540 
2541  const int64_t outputWidth =
2542  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2543  scaleInt[3]) +
2544  1;
2545 
2546  if (outputHeight < 0 || outputWidth < 0) {
2547  return emitOptionalError(
2548  location,
2549  "calculated output height and width must be non-negative, "
2550  "got height = ",
2551  outputHeight, ", width = ", outputWidth);
2552  }
2553 
2554  outputShape[1] = outputHeight;
2555  outputShape[2] = outputWidth;
2556  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2557  return success();
2558 }
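
// Illustrative sketch (editor's addition, not in the upstream file): for an
// input of tensor<1x14x14x3xf32> with scale = [2, 1, 2, 1], offset = [0, 0]
// and border = [0, 0], both spatial dims become
// ((14 - 1) * 2 - 0 + 0) / 1 + 1 = 27, i.e. tensor<1x27x27x3xf32>.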
2559 
2560 LogicalResult tosa::ResizeOp::verify() {
2561  const Value input = getInput();
2562  const Value output = getOutput();
2563  const RankedTensorType inputType =
2564  llvm::dyn_cast<RankedTensorType>(input.getType());
2565  const RankedTensorType outputType =
2566  llvm::dyn_cast<RankedTensorType>(output.getType());
2567 
2568  SmallVector<int64_t> scaleValues;
2569  SmallVector<int64_t> offsetValues;
2570  SmallVector<int64_t> borderValues;
2571  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2572  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2573  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2574  // Skip following checks if shape is not constant
2575  return success();
2576  }
2577 
2578  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2579  return emitOpError("expect all scale values to be > 0, got ")
2580  << scaleValues;
2581 
2582  const int64_t scaleYN = scaleValues[0];
2583  const int64_t scaleYD = scaleValues[1];
2584  const int64_t scaleXN = scaleValues[2];
2585  const int64_t scaleXD = scaleValues[3];
2586 
2587  const int64_t offsetY = offsetValues[0];
2588  const int64_t offsetX = offsetValues[1];
2589 
2590  const int64_t borderY = borderValues[0];
2591  const int64_t borderX = borderValues[1];
2592 
2593  if (!inputType)
2594  return success();
2595  if (!outputType)
2596  return success();
2597 
2598  const int64_t oh = outputType.getDimSize(1);
2599  const int64_t ow = outputType.getDimSize(2);
2600  const int64_t ih = inputType.getDimSize(1);
2601  const int64_t iw = inputType.getDimSize(2);
2602 
2603  // Skip the check when the input height could be broadcast (ih == 1),
2604  // since Linalg, a consumer of TOSA, expects broadcasting support in
2605  // resize to be available. Taking the cautious approach for now; we can
2606  // consider removing support for broadcasting later.
2607  if (ih != ShapedType::kDynamic && ih != 1) {
2608  const std::optional<int64_t> calculatedOutHeightMinusOne =
2609  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2610  if (!calculatedOutHeightMinusOne.has_value())
2611  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2612  "border_y ")
2613  << "to be wholly divisible by scale_y_d, got ((" << ih
2614  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2615  << ") / " << scaleYD;
2616  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2617  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2618  return emitOpError("calculated output height did not match expected: ")
2619  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2620  }
2621 
2622  // Skip the check when the input width could be broadcast (iw == 1),
2623  // since Linalg, a consumer of TOSA, expects broadcasting support in
2624  // resize to be available. Taking the cautious approach for now; we can
2625  // consider removing support for broadcasting later.
2626  if (iw != ShapedType::kDynamic && iw != 1) {
2627  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2628  const std::optional<int64_t> calculatedOutWidthMinusOne =
2629  idivCheck(scaledInWidth, scaleXD);
2630  if (!calculatedOutWidthMinusOne.has_value())
2631  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2632  "border_x ")
2633  << "to be wholly divisible by scale_x_d, got ((" << iw
2634  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2635  << ") / " << scaleXD;
2636  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2637  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2638  return emitOpError("calculated output width did not match expected: ")
2639  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2640  }
2641 
2642  return success();
2643 }
2644 
2645 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2646  MLIRContext *context, ::std::optional<Location> location,
2647  ScatterOp::Adaptor adaptor,
2648  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2649  llvm::SmallVector<int64_t> outputShape;
2650  outputShape.resize(3, ShapedType::kDynamic);
2651 
2652  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2653  if (valuesInShape.hasRank()) {
2654  outputShape[0] = valuesInShape.getDimSize(0);
2655  outputShape[1] = valuesInShape.getDimSize(1);
2656  outputShape[2] = valuesInShape.getDimSize(2);
2657  }
2658 
2659  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2660  if (indicesShape.hasRank()) {
2661  if (outputShape[0] == ShapedType::kDynamic)
2662  outputShape[0] = indicesShape.getDimSize(0);
2663  }
2664 
2665  ShapeAdaptor inputShape(adaptor.getInput().getType());
2666  if (inputShape.hasRank()) {
2667  if (outputShape[0] == ShapedType::kDynamic)
2668  outputShape[0] = inputShape.getDimSize(0);
2669  if (outputShape[2] == ShapedType::kDynamic)
2670  outputShape[2] = inputShape.getDimSize(2);
2671  }
2672 
2673  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2674  return success();
2675 }
2676 
2677 LogicalResult tosa::ScatterOp::verify() {
2678  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2679  /* outType = */ getValuesOut().getType())
2680  .failed() ||
2681  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2682  /* outType = */ getValuesOut().getType())
2683  .failed()) {
2684  return failure();
2685  }
2686 
2687  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2688  const ShapeAdaptor indicesShape(getIndices().getType());
2689  const ShapeAdaptor inputShape(getInput().getType());
2690  const ShapeAdaptor outputShape(getValuesOut().getType());
2691 
2692  int64_t N = ShapedType::kDynamic;
2693  int64_t K = ShapedType::kDynamic;
2694  int64_t W = ShapedType::kDynamic;
2695  int64_t C = ShapedType::kDynamic;
2696  if (valuesInShape.hasRank()) {
2697  N = valuesInShape.getDimSize(0);
2698  K = valuesInShape.getDimSize(1);
2699  C = valuesInShape.getDimSize(2);
2700  }
2701  if (indicesShape.hasRank()) {
2702  const int64_t indicesN = indicesShape.getDimSize(0);
2703  W = indicesShape.getDimSize(1);
2704  if (N == ShapedType::kDynamic)
2705  N = indicesN;
2706  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2707  return emitOpError() << "requires indices dimension 0 to have size " << N
2708  << ", got " << indicesN;
2709  }
2710  if (inputShape.hasRank()) {
2711  const int64_t inputN = inputShape.getDimSize(0);
2712  const int64_t inputW = inputShape.getDimSize(1);
2713  const int64_t inputC = inputShape.getDimSize(2);
2714  if (N == ShapedType::kDynamic)
2715  N = inputN;
2716  else if (inputN != ShapedType::kDynamic && N != inputN)
2717  return emitOpError() << "requires input dimension 0 to have size " << N
2718  << ", got " << inputN;
2719  if (W == ShapedType::kDynamic)
2720  W = inputW;
2721  else if (inputW != ShapedType::kDynamic && W != inputW)
2722  return emitOpError() << "requires input dimension 1 to have size " << W
2723  << ", got " << inputW;
2724 
2725  if (C == ShapedType::kDynamic)
2726  C = inputC;
2727  else if (inputC != ShapedType::kDynamic && C != inputC)
2728  return emitOpError() << "requires input dimension 2 to have size " << C
2729  << ", got " << inputC;
2730  }
2731  if (outputShape.hasRank()) {
2732  const int64_t outputN = outputShape.getDimSize(0);
2733  const int64_t outputK = outputShape.getDimSize(1);
2734  const int64_t outputC = outputShape.getDimSize(2);
2735  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2736  N != outputN)
2737  return emitOpError() << "requires values_out dimension 0 to have size "
2738  << N << ", got " << outputN;
2739  if (K == ShapedType::kDynamic)
2740  K = outputK;
2741  else if (outputK != ShapedType::kDynamic && K != outputK)
2742  return emitOpError() << "requires values_out dimension 1 to have size "
2743  << K << ", got " << outputK;
2744  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2745  C != outputC)
2746  return emitOpError() << "requires values_out dimension 2 to have size "
2747  << C << ", got " << outputC;
2748  }
2749  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2750  return emitOpError() << "requires dimensions K >= W, got K=" << K
2751  << " and W=" << W;
2752 
2753  return success();
2754 }
2755 
2756 static LogicalResult ReduceInferReturnTypes(
2757  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2758  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2759  int64_t axisVal = axis.getValue().getSExtValue();
2760  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2761  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2762  return success();
2763  }
2764 
2765  SmallVector<int64_t> outputShape;
2766  operandShape.getDims(outputShape);
2767  outputShape[axisVal] = 1;
2768  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2769  return success();
2770 }
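
// Illustrative sketch (editor's addition, not in the upstream file): reducing
// tensor<2x3x4xf32> along axis = 1 keeps the operand shape but sets the
// reduced dim to 1, giving an inferred output of tensor<2x1x4xf32>.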
2771 
2772 #define COMPATIBLE_RETURN_TYPES(OP) \
2773  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2774  if (l.size() != r.size() || l.size() != 1) \
2775  return false; \
2776  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2777  return false; \
2778  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2779  }
2780 
2781 #define REDUCE_SHAPE_INFER(OP) \
2782  LogicalResult OP::inferReturnTypeComponents( \
2783  MLIRContext *context, ::std::optional<Location> location, \
2784  OP::Adaptor adaptor, \
2785  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2786  Type inputType = \
2787  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2788  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2789  const Properties &prop = adaptor.getProperties(); \
2790  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2791  inferredReturnShapes); \
2792  } \
2793  COMPATIBLE_RETURN_TYPES(OP)
2794 
2795 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2796 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2797 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2798 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2799 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2800 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2801 #undef REDUCE_SHAPE_INFER
2802 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2803 #undef COMPATIBLE_RETURN_TYPES
2804 
2805 template <typename T>
2806 static LogicalResult verifyReduceOp(T op) {
2807  // All TOSA reduce Ops have input, output and axis.
2808  TensorType inputType = op.getInput().getType();
2809  TensorType outputType = op.getOutput().getType();
2810  int32_t reduceAxis = op.getAxis();
2811 
2812  if (reduceAxis < 0) {
2813  op.emitOpError("reduce axis must not be negative");
2814  return failure();
2815  }
2816  if (inputType.hasRank()) {
2817  int64_t inputRank = inputType.getRank();
2818  // We allow for a special case where the input/output shape has rank 0 and
2819  // axis is also 0.
2820  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2821  op.emitOpError("expect input tensor rank (")
2822  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2823  << ")";
2824  return failure();
2825  }
2826  }
2827  if (outputType.hasRank()) {
2828  int64_t outputRank = outputType.getRank();
2829  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2830  op.emitOpError(
2831  "expect output tensor rank to be equal to input tensor rank");
2832  return failure();
2833  }
2834  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2835  op.emitOpError("expect output tensor rank (")
2836  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2837  << ")";
2838  return failure();
2839  }
2840  // We can only verify the reduced dimension size to be 1 if this is not
2841  // the special case of output rank == 0.
2842  if (outputRank != 0) {
2843  auto outputShape = outputType.getShape();
2844  if (!outputType.isDynamicDim(reduceAxis) &&
2845  outputShape[reduceAxis] != 1) {
2846  op.emitOpError("expect reduced dimension size to be 1, got ")
2847  << outputShape[reduceAxis];
2848  return failure();
2849  }
2850  }
2851  }
2852  return success();
2853 }
2854 
2855 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2856 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2857 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2858 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2859 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2860 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2861 
2862 static LogicalResult NAryInferReturnTypes(
2863  const ValueShapeRange &operands,
2864  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2865  llvm::SmallVector<int64_t> outShape;
2866  if (resolveBroadcastShape(operands, outShape).failed()) {
2867  inferredReturnShapes.push_back(ShapedTypeComponents());
2868  } else {
2869  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2870  }
2871  return success();
2872 }
2873 
2874 #define NARY_SHAPE_INFER(OP) \
2875  LogicalResult OP::inferReturnTypeComponents( \
2876  MLIRContext *context, ::std::optional<Location> location, \
2877  ValueShapeRange operands, DictionaryAttr attributes, \
2878  OpaqueProperties properties, RegionRange regions, \
2879  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2880  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2881  }
2882 
2883 NARY_SHAPE_INFER(tosa::AbsOp)
2884 NARY_SHAPE_INFER(tosa::AddOp)
2885 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2886 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2887 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2888 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2889 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2890 NARY_SHAPE_INFER(tosa::CastOp)
2891 NARY_SHAPE_INFER(tosa::CeilOp)
2892 NARY_SHAPE_INFER(tosa::ClampOp)
2893 NARY_SHAPE_INFER(tosa::ClzOp)
2894 NARY_SHAPE_INFER(tosa::CosOp)
2895 NARY_SHAPE_INFER(tosa::ExpOp)
2896 NARY_SHAPE_INFER(tosa::FloorOp)
2897 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2898 NARY_SHAPE_INFER(tosa::GreaterOp)
2899 NARY_SHAPE_INFER(tosa::IdentityOp)
2900 NARY_SHAPE_INFER(tosa::IntDivOp)
2901 NARY_SHAPE_INFER(tosa::LogOp)
2902 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2903 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2904 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2905 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2906 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2907 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2908 NARY_SHAPE_INFER(tosa::MaximumOp)
2909 NARY_SHAPE_INFER(tosa::MinimumOp)
2910 NARY_SHAPE_INFER(tosa::PowOp)
2911 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2912 NARY_SHAPE_INFER(tosa::ReverseOp)
2913 NARY_SHAPE_INFER(tosa::RsqrtOp)
2914 NARY_SHAPE_INFER(tosa::SinOp)
2915 NARY_SHAPE_INFER(tosa::SelectOp)
2916 NARY_SHAPE_INFER(tosa::SubOp)
2917 NARY_SHAPE_INFER(tosa::TanhOp)
2918 NARY_SHAPE_INFER(tosa::ErfOp)
2919 NARY_SHAPE_INFER(tosa::SigmoidOp)
2920 #undef NARY_SHAPE_INFER
2921 
2922 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2923  MLIRContext *context, ::std::optional<Location> location,
2924  NegateOp::Adaptor adaptor,
2925  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2926  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2927  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2928  return success();
2929 }
2930 
2931 LogicalResult tosa::NegateOp::verify() {
2932  // Verify same element type
2933  const Type input1Type = getInput1().getType();
2934  const Type outputType = getOutput().getType();
2935  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2936  return failure();
2937 
2938  // Verify same shape
2939  const SmallVector<Type, 2> types = {input1Type, outputType};
2940  if (failed(verifyCompatibleShapes(types)))
2941  return emitOpError() << "requires the same shape for input1 and output";
2942 
2943  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2944  const Type input1ZpEType =
2945  getStorageElementTypeOrSelf(getInput1Zp().getType());
2946  if (input1EType != input1ZpEType) {
2947  return emitOpError("expect both input1 and its zero point are the same "
2948  "element type, got ")
2949  << input1EType << " and " << input1ZpEType;
2950  }
2951  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2952  const Type outputZpEType =
2953  getStorageElementTypeOrSelf(getOutputZp().getType());
2954  if (outputEType != outputZpEType) {
2955  return emitOpError("expect both output and its zero point are the same "
2956  "element type, got ")
2957  << outputEType << " and " << outputZpEType;
2958  }
2959 
2960  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2961  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2962  return failure();
2963 
2964  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2965  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2966  return failure();
2967 
2968  return success();
2969 }
2970 
2971 static LogicalResult poolingInferReturnTypes(
2972  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2973  ArrayRef<int64_t> pad,
2974  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2975  llvm::SmallVector<int64_t> outputShape;
2976  outputShape.resize(4, ShapedType::kDynamic);
2977 
 2978  // If the input shape is unavailable, all we know is the output rank; leave
 every dimension dynamic.
2979  if (!inputShape) {
2980  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2981  return success();
2982  }
2983 
 2984  // The batch and channel dimensions pass through pooling unchanged.
2985  outputShape[0] = inputShape.getDimSize(0);
2986  outputShape[3] = inputShape.getDimSize(3);
2987 
2988  int64_t height = inputShape.getDimSize(1);
2989  int64_t width = inputShape.getDimSize(2);
2990 
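  // The output spatial size follows
  //   out = (in + pad_front + pad_back - kernel) / stride + 1,
  // e.g. in = 8, pad = {1, 1}, kernel = 3, stride = 2 gives
  // (8 + 1 + 1 - 3) / 2 + 1 = 4.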
2991  if (!ShapedType::isDynamic(height)) {
2992  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2993  outputShape[1] = padded / stride[0] + 1;
2994  }
2995 
2996  if (!ShapedType::isDynamic(width)) {
2997  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2998  outputShape[2] = padded / stride[1] + 1;
2999  }
3000 
3001  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3002  return success();
3003 }
3004 
3005 LogicalResult Conv2DOp::inferReturnTypeComponents(
3006  MLIRContext *context, ::std::optional<Location> location,
3007  Conv2DOp::Adaptor adaptor,
3008  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3009  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3010 
3011  int64_t inputWidth = ShapedType::kDynamic;
3012  int64_t inputHeight = ShapedType::kDynamic;
3013  int64_t weightWidth = ShapedType::kDynamic;
3014  int64_t weightHeight = ShapedType::kDynamic;
3015 
3016  // Input shape describes input width/height and batch.
3017 
3018  ShapeAdaptor inputShape(adaptor.getInput().getType());
3019  if (inputShape.hasRank()) {
3020  outputShape[0] = inputShape.getDimSize(0);
3021  inputHeight = inputShape.getDimSize(1);
3022  inputWidth = inputShape.getDimSize(2);
3023  }
3024 
 3025  // Weight shape describes the filter height/width and the output channels.
3026  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3027  if (weightShape.hasRank()) {
3028  outputShape[3] = weightShape.getDimSize(0);
3029  weightHeight = weightShape.getDimSize(1);
3030  weightWidth = weightShape.getDimSize(2);
3031  }
3032 
3033  // Bias shape can describe the output channels.
3034  ShapeAdaptor biasShape(adaptor.getBias().getType());
3035  if (biasShape.hasRank()) {
3036  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3037  ? biasShape.getDimSize(0)
3038  : outputShape[3];
3039  }
3040 
3041  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3042  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3043  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3044 
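  // The convolution output size follows
  //   out = (in + pad_front + pad_back - dilation * (kernel - 1) - 1) / stride + 1,
  // e.g. in = 16, pad = {1, 1}, kernel = 3, dilation = 1, stride = 2 gives
  // (16 + 1 + 1 - 2 - 1) / 2 + 1 = 8.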
3045  if (!ShapedType::isDynamic(inputHeight) &&
3046  !ShapedType::isDynamic(weightHeight)) {
3047  int64_t inputSize = inputHeight + padding[0] + padding[1];
3048  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3049  int64_t unstridedResult = inputSize - filterSize + 1;
3050  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3051  }
3052 
3053  if (!ShapedType::isDynamic(inputWidth) &&
3054  !ShapedType::isDynamic(weightWidth)) {
3055  int64_t inputSize = inputWidth + padding[2] + padding[3];
3056  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3057  int64_t unstridedResult = inputSize - filterSize + 1;
3058  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3059  }
3060 
3061  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3062  return success();
3063 }
3064 
3065 LogicalResult Conv2DOp::verify() {
3066  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3067  verifyConvOpErrorIf(*this).failed())
3068  return failure();
3069  return success();
3070 }
3071 
3072 LogicalResult Conv3DOp::inferReturnTypeComponents(
3073  MLIRContext *context, ::std::optional<Location> location,
3074  Conv3DOp::Adaptor adaptor,
3075  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3076  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3077 
3078  int64_t inputWidth = ShapedType::kDynamic;
3079  int64_t inputHeight = ShapedType::kDynamic;
3080  int64_t inputDepth = ShapedType::kDynamic;
3081 
3082  int64_t weightWidth = ShapedType::kDynamic;
3083  int64_t weightHeight = ShapedType::kDynamic;
3084  int64_t weightDepth = ShapedType::kDynamic;
3085 
 3086  // Input shape describes input depth/height/width and batch.
3087  ShapeAdaptor inputShape(adaptor.getInput().getType());
3088  if (inputShape.hasRank()) {
3089  outputShape[0] = inputShape.getDimSize(0);
3090  inputDepth = inputShape.getDimSize(1);
3091  inputHeight = inputShape.getDimSize(2);
3092  inputWidth = inputShape.getDimSize(3);
3093  }
3094 
 3095  // Weight shape describes the filter depth/height/width and the output channels.
3096  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3097  if (weightShape.hasRank()) {
3098  outputShape[4] = weightShape.getDimSize(0);
3099  weightDepth = weightShape.getDimSize(1);
3100  weightHeight = weightShape.getDimSize(2);
3101  weightWidth = weightShape.getDimSize(3);
3102  }
3103 
3104  // Bias shape can describe the output channels.
3105  ShapeAdaptor biasShape(adaptor.getBias().getType());
3106  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3107  outputShape[4] = biasShape.getDimSize(0);
3108  }
3109 
3110  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3111  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3112  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3113 
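  // Same output-size computation as in Conv2DOp::inferReturnTypeComponents,
  // applied independently to the depth, height and width dimensions.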
3114  if (!ShapedType::isDynamic(inputDepth) &&
3115  !ShapedType::isDynamic(weightDepth)) {
 3116  int64_t inputSize = inputDepth + pad[0] + pad[1];
 3117  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
 3118  int64_t unstridedResult = inputSize - filterSize + 1;
3119  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3120  }
3121 
3122  if (!ShapedType::isDynamic(inputHeight) &&
3123  !ShapedType::isDynamic(weightHeight)) {
 3124  int64_t inputSize = inputHeight + pad[2] + pad[3];
 3125  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
 3126  int64_t unstridedResult = inputSize - filterSize + 1;
3127  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3128  }
3129 
3130  if (!ShapedType::isDynamic(inputWidth) &&
3131  !ShapedType::isDynamic(weightWidth)) {
 3132  int64_t inputSize = inputWidth + pad[4] + pad[5];
 3133  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
 3134  int64_t unstridedResult = inputSize - filterSize + 1;
3135  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3136  }
3137 
3138  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3139  return success();
3140 }
3141 
3142 LogicalResult Conv3DOp::verify() {
3143  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3144  verifyConvOpErrorIf(*this).failed())
3145  return failure();
3146  return success();
3147 }
3148 
3149 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3150  MLIRContext *context, ::std::optional<Location> location,
3151  AvgPool2dOp::Adaptor adaptor,
3152  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3153  ShapeAdaptor inputShape(adaptor.getInput().getType());
3154  const Properties &prop = adaptor.getProperties();
3155  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3156  inferredReturnShapes);
3157 }
3158 
3159 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3160  MLIRContext *context, ::std::optional<Location> location,
3161  MaxPool2dOp::Adaptor adaptor,
3162  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3163  ShapeAdaptor inputShape(adaptor.getInput().getType());
3164  const Properties &prop = adaptor.getProperties();
3165  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3166  inferredReturnShapes);
3167 }
3168 
3169 LogicalResult MaxPool2dOp::verify() {
 3170  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3171  /* outType = */ getOutput().getType())))
3172  return failure();
3173 
3174  if (failed(verifyPoolingOp(*this)))
3175  return failure();
3176 
3177  return success();
3178 }
3179 
3180 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3181  MLIRContext *context, ::std::optional<Location> location,
3182  DepthwiseConv2DOp::Adaptor adaptor,
3183  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3184  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3185 
3186  int64_t inputWidth = ShapedType::kDynamic;
3187  int64_t inputHeight = ShapedType::kDynamic;
3188  int64_t inputChannels = ShapedType::kDynamic;
3189 
3190  int64_t weightWidth = ShapedType::kDynamic;
3191  int64_t weightHeight = ShapedType::kDynamic;
3192  int64_t depthChannels = ShapedType::kDynamic;
3193 
 3194  // Input shape describes the batch, input height/width and channels.
3195  ShapeAdaptor inputShape(adaptor.getInput().getType());
3196  if (inputShape.hasRank()) {
3197  outputShape[0] = inputShape.getDimSize(0);
3198  inputHeight = inputShape.getDimSize(1);
3199  inputWidth = inputShape.getDimSize(2);
3200  inputChannels = inputShape.getDimSize(3);
3201  }
3202 
 3203  // Weight shape describes the filter height/width, the input channels and
 the channel multiplier.
3204  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3205  if (weightShape.hasRank()) {
3206  weightHeight = weightShape.getDimSize(0);
3207  weightWidth = weightShape.getDimSize(1);
3208  inputChannels = ShapedType::isDynamic(inputChannels)
3209  ? weightShape.getDimSize(2)
3210  : inputChannels;
3211  depthChannels = weightShape.getDimSize(3);
3212  }
3213 
3214  // If both inputChannels and depthChannels are available we can determine
3215  // the output channels.
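  // For example, 8 input channels with a channel multiplier of 3 yield
  // 8 * 3 = 24 output channels.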
3216  if (!ShapedType::isDynamic(inputChannels) &&
3217  !ShapedType::isDynamic(depthChannels)) {
3218  outputShape[3] = inputChannels * depthChannels;
3219  }
3220 
3221  // Bias shape can describe the output channels.
3222  ShapeAdaptor biasShape(adaptor.getBias().getType());
3223  if (biasShape.hasRank()) {
3224  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3225  ? biasShape.getDimSize(0)
3226  : outputShape[3];
3227  }
3228 
3229  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3230  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3231  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3232 
3233  if (!ShapedType::isDynamic(inputHeight) &&
3234  !ShapedType::isDynamic(weightHeight)) {
3235  int64_t inputSize = inputHeight + padding[0] + padding[1];
3236  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3237  int64_t unstridedResult = inputSize - filterSize + 1;
3238  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3239  }
3240 
3241  if (!ShapedType::isDynamic(inputWidth) &&
3242  !ShapedType::isDynamic(weightWidth)) {
3243  int64_t inputSize = inputWidth + padding[2] + padding[3];
3244  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3245  int64_t unstridedResult = inputSize - filterSize + 1;
3246  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3247  }
3248 
3249  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3250  return success();
3251 }
3252 
3253 LogicalResult DepthwiseConv2DOp::verify() {
3254  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3255  verifyConvOpErrorIf(*this).failed())
3256  return failure();
3257  return success();
3258 }
3259 
3260 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3261  MLIRContext *context, ::std::optional<Location> location,
3262  TransposeConv2DOp::Adaptor adaptor,
3263  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3264  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3265 
3266  int64_t inputWidth = ShapedType::kDynamic;
3267  int64_t inputHeight = ShapedType::kDynamic;
3268  int64_t weightWidth = ShapedType::kDynamic;
3269  int64_t weightHeight = ShapedType::kDynamic;
3270 
3271  // Input shape describes input width/height and batch.
3272  ShapeAdaptor inputShape(adaptor.getInput().getType());
3273  if (inputShape.hasRank()) {
3274  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3275  ? inputShape.getDimSize(0)
3276  : outputShape[0];
3277  inputHeight = inputShape.getDimSize(1);
3278  inputWidth = inputShape.getDimSize(2);
3279  }
3280 
 3281  // Weight shape describes the filter height/width and the output channels.
3282  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3283  if (weightShape.hasRank()) {
3284  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3285  ? weightShape.getDimSize(0)
3286  : outputShape[3];
3287  weightHeight = weightShape.getDimSize(1);
3288  weightWidth = weightShape.getDimSize(2);
3289  }
3290 
3291  // Bias shape can describe the output channels.
 3292  ShapeAdaptor biasShape(adaptor.getBias().getType());
3293  if (biasShape.hasRank()) {
3294  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3295  ? biasShape.getDimSize(0)
3296  : outputShape[3];
3297  }
3298 
3299  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3300  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3301 
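  // The transposed-convolution output size follows
  //   out = (in - 1) * stride + out_pad_front + out_pad_back + kernel,
  // e.g. in = 8, stride = 2, out_pad = {0, 0}, kernel = 3 gives
  // (8 - 1) * 2 + 0 + 0 + 3 = 17.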
3302  if (!ShapedType::isDynamic(inputHeight) &&
3303  !ShapedType::isDynamic(weightHeight)) {
3304  int64_t calculateSize =
3305  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3306  outputShape[1] =
3307  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3308  }
3309 
3310  if (!ShapedType::isDynamic(inputWidth) &&
3311  !ShapedType::isDynamic(weightWidth)) {
3312  int64_t calculateSize =
3313  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3314  outputShape[2] =
3315  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3316  }
3317 
3318  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3319  return success();
3320 }
3321 
3322 LogicalResult TransposeConv2DOp::verify() {
3323  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3324  return failure();
3325 
3326  const llvm::ArrayRef<int64_t> strides = getStride();
3327  const int64_t strideY = strides[0];
3328  const int64_t strideX = strides[1];
3329 
3330  if (strideY < 1 || strideX < 1)
3331  return emitOpError("expect all stride values to be >= 1, got [")
3332  << strides << "]";
3333 
3334  const auto checkPadAgainstKernelDim =
3335  [this](int64_t pad_value, int64_t kernel_dim_size,
3336  llvm::StringRef pad_name,
3337  llvm::StringRef kernel_dim_name) -> LogicalResult {
3338  if (pad_value <= -kernel_dim_size)
3339  return emitOpError("expected ")
3340  << pad_name << " > -" << kernel_dim_name
3341  << ", but got: " << pad_name << "=" << pad_value << " and "
3342  << kernel_dim_name << "=" << kernel_dim_size;
3343  return success();
3344  };
3345 
3346  const llvm::ArrayRef<int64_t> padding = getOutPad();
3347  const int64_t outPadTop = padding[0];
3348  const int64_t outPadBottom = padding[1];
3349  const int64_t outPadLeft = padding[2];
3350  const int64_t outPadRight = padding[3];
3351 
3352  const auto weightType =
3353  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3354 
3355  if (weightType) {
3356  const int64_t kernelHeight = weightType.getDimSize(1);
3357  if (!ShapedType::isDynamic(kernelHeight)) {
3358  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3359  "out_pad_top", "KH")))
3360  return failure();
3361 
3362  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3363  "out_pad_bottom", "KH")))
3364  return failure();
3365  }
3366 
3367  const int64_t kernelWidth = weightType.getDimSize(2);
3368  if (!ShapedType::isDynamic(kernelWidth)) {
3369  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3370  "out_pad_left", "KW")))
3371  return failure();
3372 
3373  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3374  "out_pad_right", "KW")))
3375  return failure();
3376  }
3377  }
3378 
3379  // Rest of the checks depend on the output type being a RankedTensorType
3380  const auto outputType =
3381  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3382  if (!outputType)
3383  return success();
3384 
3385  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3386  if (inputType && weightType) {
3387  const int64_t inputHeight = inputType.getDimSize(1);
3388  const int64_t kernelHeight = weightType.getDimSize(1);
3389  const int64_t outputHeight = outputType.getDimSize(1);
3390 
3391  if (!ShapedType::isDynamic(inputHeight) &&
3392  !ShapedType::isDynamic(outputHeight)) {
3393  if (outputHeight !=
3394  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3395  return emitOpError(
3396  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3397  "+ out_pad_top + out_pad_bottom + KH, but got ")
3398  << outputHeight << " != (" << inputHeight << " - 1) * "
3399  << strideY << " + " << outPadTop << " + " << outPadBottom
3400  << " + " << kernelHeight;
3401  }
3402 
3403  const int64_t inputWidth = inputType.getDimSize(2);
3404  const int64_t kernelWidth = weightType.getDimSize(2);
3405  const int64_t outputWidth = outputType.getDimSize(2);
3406 
3407  if (!ShapedType::isDynamic(inputWidth) &&
3408  !ShapedType::isDynamic(outputWidth)) {
3409  if (outputWidth !=
3410  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3411  return emitOpError(
3412  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3413  "+ out_pad_left + out_pad_right + KW, but got ")
3414  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3415  << " + " << outPadLeft << " + " << outPadRight << " + "
3416  << kernelWidth;
3417  }
3418  }
3419 
3420  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3421 
3422  if (!biasType)
3423  return success();
3424 
3425  const int64_t biasChannels = biasType.getDimSize(0);
3426 
3427  // Skip further checks if bias is dynamic
3428  if (biasChannels == ShapedType::kDynamic)
3429  return success();
3430 
3431  const int64_t outputChannels = outputType.getDimSize(3);
3432  if (biasChannels != outputChannels && biasChannels != 1)
3433  return emitOpError(
3434  "bias channels expected to be equal to output channels (")
3435  << outputChannels << ") or 1, got " << biasChannels;
3436 
3437  return success();
3438 }
3439 
3440 LogicalResult RescaleOp::verify() {
3441  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3442  if (!inputType) {
3443  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3444  return failure();
3445  }
3446 
3447  auto inputElementType =
3448  getStorageElementTypeOrSelf(inputType.getElementType());
3449  if (!mlir::isa<IntegerType>(inputElementType)) {
3450  emitOpError("expect input to have integer element type, got ")
3451  << inputElementType;
3452  return failure();
3453  }
3454 
3455  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3456  if (!outputType) {
3457  emitOpError("expect shaped tensor for output, got ")
3458  << getOutput().getType();
3459  return failure();
3460  }
3461 
3462  auto outputElementType =
3463  getStorageElementTypeOrSelf(outputType.getElementType());
3464  if (!mlir::isa<IntegerType>(outputElementType)) {
3465  emitOpError("expect output to have integer element type, got ")
3466  << outputElementType;
3467  return failure();
3468  }
3469 
3470  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3471  .failed())
3472  return failure();
3473 
3474  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3475  .failed())
3476  return failure();
3477 
3478  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3479  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3480  return failure();
3481 
3482  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3483  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3484  return failure();
3485 
3486  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3487  if (!multiplierType) {
3488  emitOpError("expect shaped tensor for multiplier, got ")
3489  << getMultiplier().getType();
3490  return failure();
3491  }
3492 
3493  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3494  if (!shiftType) {
3495  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3496  return failure();
3497  }
3498 
3499  // multiplier element type must be i32 for scale32 = true
3500  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3501  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3502  << multiplierType.getElementType();
3503  return failure();
3504  }
3505 
3506  // multiplier element type must be i16 for scale32 = false
3507  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3508  emitOpError(
3509  "expect i16 element type for multiplier for scale32=false, got ")
3510  << multiplierType.getElementType();
3511  return failure();
3512  }
3513 
3514  if (!inputType.hasRank())
3515  return success();
3516 
 3517  // multiplier/shift must have shape = {numChannels}, where numChannels is
 3518  // 1 if per_channel = false, otherwise numChannels is the size of the
 3519  // last axis of the input shape
3520  int64_t numChannels = 1;
3521  if (getPerChannel()) {
3522  if (inputType.getRank() < 1) {
3523  emitOpError("requires input to be at least rank 1 when per_channel is "
3524  "true, but got rank ")
3525  << inputType.getRank();
3526  return failure();
3527  }
3528  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3529  }
3530 
3531  if (!multiplierType.hasRank())
3532  return success();
3533 
3534  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3535  // multiplier input has rank 1 by dialect definition
3536  if (multiplierShape[0] != ShapedType::kDynamic &&
3537  multiplierShape[0] != numChannels) {
3538  emitOpError("expect shape of { ")
3539  << numChannels << " } for multiplier input, got { "
3540  << multiplierShape[0] << " }";
3541  return failure();
3542  }
3543 
3544  if (!shiftType.hasRank())
3545  return success();
3546 
3547  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3548  // shift input has rank 1 by dialect definition
3549  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3550  emitOpError("expect shape of { ")
3551  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3552  return failure();
3553  }
3554 
3555  return success();
3556 }
3557 
3558 LogicalResult RescaleOp::inferReturnTypeComponents(
3559  MLIRContext *context, ::std::optional<Location> location,
3560  RescaleOp::Adaptor adaptor,
3561  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3562  ShapeAdaptor inputShape(adaptor.getInput().getType());
3563  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3564  return success();
3565 }
3566 
3567 LogicalResult IfOp::inferReturnTypeComponents(
3568  MLIRContext *context, ::std::optional<Location> location,
3569  IfOp::Adaptor adaptor,
3570  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 3571  llvm::SmallVector<tosa::YieldOp> yieldOps;
 3572  for (Region *region : adaptor.getRegions()) {
3573  for (auto &block : *region)
3574  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3575  yieldOps.push_back(returnOp);
3576  }
3577 
3578  if (yieldOps.empty())
3579  return failure();
3580 
3581  // Get the initial type information for the yield op.
3582  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3583  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3584  for (auto operand : yieldOps.front().getOperands()) {
3585  resultKnowledge.push_back(
3586  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3587  }
3588 
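  // Meet the knowledge gathered from each yield: dimensions on which the
  // yields disagree are widened to dynamic in the inferred result shape.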
3589  for (auto yieldOp : yieldOps) {
3590  if (resultKnowledge.size() != yieldOp.getNumOperands())
3591  return failure();
3592 
3593  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3594  int32_t index = it.index();
3595  auto meet = ValueKnowledge::meet(
3596  resultKnowledge[index],
3597  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3598  if (!meet)
3599  continue;
3600  resultKnowledge[index] = meet;
3601  }
3602  }
3603 
3604  for (const ValueKnowledge &result : resultKnowledge) {
3605  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3606  }
3607 
3608  return success();
3609 }
3610 
3611 LogicalResult WhileOp::inferReturnTypeComponents(
3612  MLIRContext *context, ::std::optional<Location> location,
3613  WhileOp::Adaptor adaptor,
3614  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 3615  llvm::SmallVector<tosa::YieldOp> yieldOps;
 3616  for (auto &block : adaptor.getBodyGraph())
3617  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3618  yieldOps.push_back(returnOp);
3619 
3620  // TOSA's while must have a tosa.yield as its terminator. If not found this
3621  // tosa.while is invalid.
3622  if (yieldOps.empty())
3623  return failure();
3624 
3625  // Get the initial type information from the operand types.
3626  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3627  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3628  for (auto operand : yieldOps.front().getOperands()) {
3629  resultKnowledge.push_back(
3630  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3631  }
3632 
3633  for (auto yieldOp : yieldOps) {
3634  if (resultKnowledge.size() != yieldOp.getNumOperands())
3635  return failure();
3636 
3637  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3638  int32_t index = it.index();
3639  if (auto meet = ValueKnowledge::meet(
3640  resultKnowledge[index],
3641  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3642  resultKnowledge[index] = meet;
3643  }
3644  }
3645  }
3646 
3647  for (const ValueKnowledge &result : resultKnowledge) {
3648  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3649  }
3650 
3651  return success();
3652 }
3653 
3654 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3655  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3656  return llvm::to_vector<4>(vt.getShape());
3657  return std::nullopt;
3658 }
3659 
 3660 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
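// An illustrative example of the custom assembly this parser/printer handles
// (value names are placeholders):
//
//   tosa.cond_if %cond -> (tensor<f32>) {
//     ...
//     tosa.yield %then_value : tensor<f32>
//   } else {
//     ...
//     tosa.yield %else_value : tensor<f32>
//   }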
3661 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3662  // Create the regions for 'then'.
3663  result.regions.reserve(2);
3664  Region *thenRegion = result.addRegion();
3665  Region *elseRegion = result.addRegion();
3666 
3667  auto &builder = parser.getBuilder();
 3668  OpAsmParser::UnresolvedOperand cond;
 3669  // Create an i1 tensor type for the boolean condition.
3670  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3671  if (parser.parseOperand(cond) ||
3672  parser.resolveOperand(cond, i1Type, result.operands))
3673  return failure();
3674  // Parse optional results type list.
3675  if (parser.parseOptionalArrowTypeList(result.types))
3676  return failure();
3677  // Parse the 'then' region.
3678  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3679  return failure();
3680 
3681  // If we find an 'else' keyword then parse the 'else' region.
3682  if (!parser.parseOptionalKeyword("else")) {
3683  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3684  return failure();
3685  }
3686 
3687  // Parse the optional attribute list.
3688  if (parser.parseOptionalAttrDict(result.attributes))
3689  return failure();
3690  return success();
3691 }
3692 
3693 void IfOp::print(OpAsmPrinter &p) {
3694  bool printBlockTerminators = false;
3695 
3696  p << " " << getCondition();
3697  if (!getResults().empty()) {
3698  p << " -> (" << getResultTypes() << ")";
3699  // Print yield explicitly if the op defines values.
3700  printBlockTerminators = true;
3701  }
3702  p << ' ';
3703  p.printRegion(getThenGraph(),
3704  /*printEntryBlockArgs=*/false,
3705  /*printBlockTerminators=*/printBlockTerminators);
3706 
 3707  // Print the 'else' region if it exists and has a block.
3708  auto &elseRegion = getElseGraph();
3709  if (!elseRegion.empty()) {
3710  p << " else ";
3711  p.printRegion(elseRegion,
3712  /*printEntryBlockArgs=*/false,
3713  /*printBlockTerminators=*/printBlockTerminators);
3714  }
3715 
3716  p.printOptionalAttrDict((*this)->getAttrs());
3717 }
3718 
3719 LogicalResult IfOp::verify() {
3720  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3721  "'then_graph' arguments", getInputList(),
3722  "'input_list'")
3723  .failed())
3724  return failure();
3725 
3726  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3727  "'else_graph' arguments", getInputList(),
3728  "'input_list'")
3729  .failed())
3730  return failure();
3731 
3732  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3733  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3734  "'then_graph' results", getOutputList(),
3735  "'output_list'")
3736  .failed())
3737  return failure();
3738 
3739  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3740  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3741  "'else_graph' results", getOutputList(),
3742  "'output_list'")
3743  .failed())
3744  return failure();
3745 
3746  auto condType = getCondition().getType();
3747  if (errorIfShapeNotSizeOne(*this, condType).failed())
3748  return emitOpError() << "'condition' must be a size 1 tensor, got "
3749  << condType;
3750 
3751  return success();
3752 }
3753 
3754 LogicalResult WhileOp::verify() {
3755  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3756  getOutputList(), "'output_list'")
3757  .failed())
3758  return failure();
3759 
3760  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3761  "'cond_graph' arguments", getInputList(),
3762  "'input_list'")
3763  .failed())
3764  return failure();
3765 
3766  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3767  "'body_graph' arguments", getInputList(),
3768  "'input_list'")
3769  .failed())
3770  return failure();
3771 
3772  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3773  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3774  "'body_graph' results", getInputList(),
3775  "'input_list'")
3776  .failed())
3777  return failure();
3778 
3779  // Condition block output must be a single element tensor with a single bool
3780  // value.
3781  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3782  if (condYield.getInputs().size() != 1)
3783  return emitOpError() << "require 'cond_graph' only have one result";
3784 
3785  auto condOutType = condYield.getInputs()[0].getType();
3786  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3787  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3788  << condOutType;
3789 
3790  if (!getElementTypeOrSelf(condOutType).isInteger(1))
3791  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3792  << condOutType;
3793 
3794  return success();
3795 }
3796 
3797 LogicalResult ReverseOp::verify() {
3798  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3799  /* outType = */ getOutput().getType())
3800  .failed())
3801  return failure();
3802  TensorType inputType = getInput1().getType();
3803  TensorType outputType = getOutput().getType();
3804  int32_t reverseAxis = getAxis();
3805 
3806  if (reverseAxis < 0)
3807  return emitOpError("expected non-negative reverse axis");
3808  if (inputType.hasRank()) {
3809  int64_t inputRank = inputType.getRank();
3810  // We allow for a special case where the input/output shape has rank 0 and
3811  // axis is also 0.
3812  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3813  return emitOpError("expect input tensor rank (")
3814  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3815  << ")";
3816  }
3817  if (outputType.hasRank()) {
3818  int64_t outputRank = outputType.getRank();
3819  if (inputType.hasRank() && outputRank != inputType.getRank())
3820  return emitOpError(
3821  "expect output tensor rank to be equal to input tensor rank");
3822  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3823  return emitOpError("expect output tensor rank (")
3824  << outputRank << ") to be larger than reverse axis ("
3825  << reverseAxis << ")";
3826  }
3827  return success();
3828 }
3829 
3830 LogicalResult tosa::SelectOp::verify() {
3831  // verify input2 and input3 have same element type as output
3832  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
3833  /* outType = */ getOutput().getType())
3834  .failed() ||
3835  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
3836  /* outType = */ getOutput().getType())
3837  .failed()) {
3838  return failure();
3839  }
3840  // verify input1 has element type of bool
3841  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
3842  if (!predicateType) {
3843  return emitOpError("expect shaped tensor for input1, got ")
3844  << getInput1().getType();
3845  }
3846  auto predicateElementType = predicateType.getElementType();
3847  if (!predicateElementType.isInteger(1)) {
3848  return emitOpError("expect element type of bool for input1, got ")
3849  << predicateElementType;
3850  }
3851 
3852  return success();
3853 }
3854 
3855 LogicalResult tosa::VariableOp::verify() {
3856  StringRef symName = getName();
3857  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3858  if (succeeded(varOp))
3859  return emitOpError("illegal to have multiple declaration of '")
3860  << symName << "'";
3861 
3862  return success();
3863 }
3864 
3865 LogicalResult tosa::VariableReadOp::verify() {
3866  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3867  .failed())
3868  return failure();
3869 
3870  return success();
3871 }
3872 
3873 LogicalResult tosa::VariableWriteOp::verify() {
3874  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3875  .failed())
3876  return failure();
3877 
3878  return success();
3879 }
3880 
 3881 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
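// An illustrative example of the custom assembly this parser/printer handles
// (value names are placeholders):
//
//   tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...
//     tosa.yield %keep_going : tensor<i1>
//   } do {
//   ^bb0(%arg0: tensor<i32>):
//     ...
//     tosa.yield %next : tensor<i32>
//   }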
3882 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 3883  SmallVector<OpAsmParser::Argument, 4> regionArgs;
 3884  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
 3885  Region *cond = result.addRegion();
3886  Region *body = result.addRegion();
3887 
3888  OptionalParseResult listResult =
3889  parser.parseOptionalAssignmentList(regionArgs, operands);
3890  if (listResult.has_value() && failed(listResult.value()))
3891  return failure();
3892 
3893  FunctionType functionType;
3894  SMLoc typeLoc = parser.getCurrentLocation();
3895  if (failed(parser.parseColonType(functionType)))
3896  return failure();
3897 
3898  result.addTypes(functionType.getResults());
3899 
3900  if (functionType.getNumInputs() != operands.size()) {
3901  return parser.emitError(typeLoc)
3902  << "expected as many input types as operands "
3903  << "(expected " << operands.size() << " got "
3904  << functionType.getNumInputs() << ")";
3905  }
3906 
3907  // Resolve input operands.
3908  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3909  parser.getCurrentLocation(),
3910  result.operands)))
3911  return failure();
3912 
3913  // Propagate the types into the region arguments.
3914  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3915  regionArgs[i].type = functionType.getInput(i);
3916 
3917  return failure(parser.parseRegion(*cond, regionArgs) ||
3918  parser.parseKeyword("do") || parser.parseRegion(*body) ||
 3919  parser.parseOptionalAttrDictWithKeyword(result.attributes));
 3920 }
3921 
 3922 static void printInitializationList(OpAsmPrinter &parser,
 3923  Block::BlockArgListType blocksArgs,
3924  ValueRange initializers,
3925  StringRef prefix = "") {
3926  assert(blocksArgs.size() == initializers.size() &&
3927  "expected same length of arguments and initializers");
3928  if (initializers.empty())
3929  return;
3930 
3931  parser << prefix << '(';
3932  llvm::interleaveComma(
3933  llvm::zip(blocksArgs, initializers), parser,
3934  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3935  parser << ")";
3936 }
3937 
3938 void WhileOp::print(OpAsmPrinter &parser) {
3939  printInitializationList(parser, getCondGraph().front().getArguments(),
3940  getInputList(), " ");
3941  parser << " : ";
3942  parser.printFunctionalType(getInputList().getTypes(),
3943  getResults().getTypes());
3944  parser << ' ';
3945  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3946  parser << " do ";
3947  parser.printRegion(getBodyGraph());
3948  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3949 }
3950 
3951 // Create a rank-1 const tensor for zero point of the source tensor.
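// For example, createZeroPointTensor(builder, loc, builder.getI8Type(), 0)
// materializes a tosa.const of type tensor<1xi8> holding the value 0.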
3952 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3953  Location loc,
3954  Type srcElemType,
3955  int64_t zp) {
3956  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3957  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3958  if (llvm::isa<FloatType>(srcElemType)) {
3959  auto zpAttr = DenseElementsAttr::get(
3960  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3961  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3962  }
3963  if (llvm::isa<IntegerType>(srcElemType)) {
3964  auto zpAttr =
3965  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3966  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
3967  }
3968  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3969  return std::nullopt;
3970 }
3971 
3972 //===----------------------------------------------------------------------===//
3973 // TOSA Shape and Shape Operators Helper functions.
3974 //===----------------------------------------------------------------------===//
3975 
 3976 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
 3977  return mlir::isa<tosa::shapeType>(t);
3978 }
3979 
3980 LogicalResult
 3981 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
 3982  int rank) {
3983  if (rank < 0)
3984  return emitError() << "invalid rank (must be >= 0): " << rank;
3985  return success();
3986 }
3987 
 3988 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
 3989  for (auto v : op->getOperands()) {
3990  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3991  Operation *definingOp = v.getDefiningOp();
3992  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3993  return op->emitOpError("shape operand is not compile time resolvable");
3994  }
3995  }
3996  }
3997  return success();
3998 }
3999 
 4000 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
 4001  for (auto type : op->getOperandTypes()) {
4002  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4003  return op->emitOpError("must have operands with tosa shape type");
4004  }
4005  }
4006  for (auto type : op->getResultTypes()) {
4007  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4008  return op->emitOpError("must have result with tosa shape type");
4009  }
4010  }
4011  return success();
4012 }
4013 
4014 LogicalResult
 4015 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
 4016  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4017  failed(verifyTosaShapeOperator(op)))
4018  return failure();
4019 
 4020  // Helper lambda that returns the rank of a tosa shape type.
4021  auto getRank = [](const Type type) {
4022  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4023  };
4024  auto operandTypes = op->getOperandTypes();
4025  auto resultTypes = op->getResultTypes();
4026 
4027  auto rank = getRank(*op->getOperandTypes().begin());
4028  for (auto type : operandTypes) {
4029  if (getRank(type) != rank) {
4030  return op->emitOpError("operands don't have matching ranks");
4031  }
4032  }
4033  for (auto type : resultTypes) {
4034  if (getRank(type) != rank) {
4035  return op->emitOpError("result shape has different rank than operands");
4036  }
4037  }
4038  return success();
4039 }
4040 
4041 //===----------------------------------------------------------------------===//
4042 // TOSA Shape Operators verify functions.
4043 //===----------------------------------------------------------------------===//
4044 
4045 LogicalResult tosa::ConstShapeOp::verify() {
 4046  // Check that the values attribute is one-dimensional (rank 1).
4047  auto valuesRank = getValues().getType().getRank();
4048  if (valuesRank != 1)
4049  return emitOpError("expect elements in attribute values with rank 1");
 4050  // Check that the number of elements in the values attribute equals the rank of the result shape.
4051  auto count = getValues().getNumElements();
4052  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4053  if (!(count == rank || (count == 1 && rank == 0))) {
4054  return emitOpError("expect number of elements in attribute values (")
4055  << count << ") to be equal to the rank (" << rank
4056  << ") for the result shape type";
4057  }
4058  return success();
4059 }
4060 
4061 //===----------------------------------------------------------------------===//
4062 // TOSA Attribute Definitions.
4063 //===----------------------------------------------------------------------===//
4064 
4065 #define GET_ATTRDEF_CLASSES
4066 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4067 
4068 //===----------------------------------------------------------------------===//
4069 // TOSA Type Definitions.
4070 //===----------------------------------------------------------------------===//
4071 #define GET_TYPEDEF_CLASSES
4072 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4073 
4074 //===----------------------------------------------------------------------===//
4075 // TOSA Operator Definitions.
4076 //===----------------------------------------------------------------------===//
4077 
4078 #define GET_OP_CLASSES
4079 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:1035
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition: TosaOps.cpp:2183
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:727
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2756
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition: TosaOps.cpp:296
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
Definition: TosaOps.cpp:674
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition: TosaOps.cpp:664
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2781
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:336
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition: TosaOps.cpp:708
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2971
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:1149
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition: TosaOps.cpp:1163
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:2806
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2874
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition: TosaOps.cpp:2253
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:1013
static LogicalResult verifyConvOpErrorIf(T op)
Definition: TosaOps.cpp:482
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:441
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:279
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2862
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:2772
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:1194
Type getStorageElementTypeOrSelf(Type type)
Definition: TosaOps.cpp:285
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition: TosaOps.cpp:1109
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:989
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:1064
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition: TosaOps.cpp:622
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:138
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition: TosaOps.cpp:2211
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:3922
static LogicalResult verifyPoolingOp(T op)
Definition: TosaOps.cpp:790
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition: TosaOps.cpp:1283
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:188
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
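A minimal sketch of how such an adaptor is typically queried in a shape-inference or verification helper; operandShape and loc are assumed to be supplied by a surrounding function that returns LogicalResult, and the zero-size check is purely illustrative.

int64_t totalElements = ShapedType::kDynamic;
if (operandShape.hasRank()) {
  SmallVector<int64_t> dims;
  operandShape.getDims(dims);   // dynamic dims come back as ShapedType::kDynamic
  for (int64_t i = 0, e = operandShape.getRank(); i < e; ++i)
    if (!operandShape.isDynamicDim(i) && operandShape.getDimSize(i) == 0)
      return emitOptionalError(loc, "zero-sized dimension is not supported");
  if (operandShape.hasStaticShape())
    totalElements = operandShape.getNumElements();   // only valid when fully static
}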
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
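A small illustrative helper (not from this file) showing how these predicates are typically combined, for example to pick a wider accumulator element type:

static Type pickAccumulatorType(OpBuilder &builder, Type elemTy) {
  if (elemTy.isF16() || elemTy.isBF16())
    return builder.getF32Type();         // accumulate 16-bit floats in f32
  if (elemTy.isF32())
    return elemTy;
  if (elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() <= 16)
    return builder.getIntegerType(32);   // accumulate narrow integers in i32
  return elemTy;
}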
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
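The two constructors above drive early-exit walks. A short sketch, where verifyNestedOp is a hypothetical per-op check:

WalkResult walkStatus = op->walk([&](Operation *nested) {
  if (failed(verifyNestedOp(nested)))
    return WalkResult::interrupt();   // stop the traversal on the first failure
  return WalkResult::advance();       // otherwise keep going
});
if (walkStatus.wasInterrupted())
  return failure();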
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:919
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:4000
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4015
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:3988
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:60
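Sketch of using the helper (it lives in the OpTrait::util namespace); lhsType and rhsType are assumed to be ranked tensor types of the two operands:

SmallVector<int64_t> resultShape;
if (!OpTrait::util::getBroadcastedShape(lhsType.getShape(), rhsType.getShape(),
                                        resultShape))
  return op->emitOpError("operand shapes are not broadcast-compatible");
// resultShape now holds the combined shape, with dynamic dimensions preserved.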
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:289
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:227
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:252
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:3952
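Sketch of materializing a zero-point constant for a given element type; builder, loc, and srcElemType are assumed to come from the surrounding code, and the helper returns std::nullopt for element types it cannot handle:

std::optional<Value> zp =
    tosa::createZeroPointTensor(builder, loc, srcElemType, /*zp=*/0);
if (!zp)
  return failure();
Value zeroPoint = *zp;   // constant tensor holding the requested zero point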
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:3976
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:316
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
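A typical constant-probing sketch combining matchPattern with the m_Constant matcher listed further below; consumeConstant is a hypothetical consumer:

DenseElementsAttr attr;
if (matchPattern(operand, m_Constant(&attr)))
  consumeConstant(attr);   // `operand` is produced by a constant-foldable op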
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair of entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
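Sketch of the usual compatibility check between an input and an output shape; inputType and outputType are assumed to be ranked tensor types, and a mismatch in any static dimension pair makes the call fail:

if (failed(verifyCompatibleShape(inputType.getShape(), outputType.getShape())))
  return op->emitOpError("input and output shapes are not compatible");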
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
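A minimal sketch of assembling an operation through OperationState before handing it to OpBuilder::create; the op name "example.op", the "axis" attribute, and the surrounding values are illustrative only:

OperationState state(loc, "example.op");
state.addOperands(input);                                  // single operand
state.addAttribute("axis", builder.getI32IntegerAttr(0));
state.addTypes(resultType);
state.addRegion();                                         // only for ops that carry a region
Operation *newOp = builder.create(state);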
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45
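A short sketch of combining the static information carried by two values, in the spirit of the control-flow type resolution done in this file; bodyResult and blockArg are illustrative names:

ValueKnowledge fromBody = ValueKnowledge::getKnowledgeFromType(bodyResult.getType());
ValueKnowledge fromArg = ValueKnowledge::getKnowledgeFromType(blockArg.getType());
// `meet` keeps only the facts both sides agree on (rank, dimensions, element type).
ValueKnowledge combined = ValueKnowledge::meet(fromBody, fromArg);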