1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://www.mlplatform.org/tosa/tosa_spec.html
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/TypeUtilities.h"
28 #include "llvm/ADT/APFloat.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 #include <numeric>
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // Tosa dialect interface includes.
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
44 #include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 #include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
47 
48 namespace {
49 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
50 
51 //===----------------------------------------------------------------------===//
52 // Dialect Function Inliner Interface.
53 //===----------------------------------------------------------------------===//
54 struct TosaInlinerInterface : public DialectInlinerInterface {
55  using DialectInlinerInterface::DialectInlinerInterface;
56 
57  //===--------------------------------------------------------------------===//
58  // Analysis Hooks.
59  //===--------------------------------------------------------------------===//
60 
61  /// All operations can be inlined by default.
62  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
63  IRMapping &map) const final {
64  return true;
65  }
66 
67  /// All regions with If and While parent operators can be inlined.
68  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
69  IRMapping &map) const final {
70  return (isa<tosa::IfOp>(dest->getParentOp()) ||
71  isa<tosa::WhileOp>(dest->getParentOp()));
72  }
73 };
74 
75 /// This class implements the bytecode interface for the Tosa dialect.
76 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
77  TosaDialectBytecodeInterface(Dialect *dialect)
78  : BytecodeDialectInterface(dialect) {}
79 
80  //===--------------------------------------------------------------------===//
81  // Attributes
82 
83  Attribute readAttribute(DialectBytecodeReader &reader) const override {
84  return ::readAttribute(getContext(), reader);
85  }
86 
87  LogicalResult writeAttribute(Attribute attr,
88  DialectBytecodeWriter &writer) const override {
89  return ::writeAttribute(attr, writer);
90  }
91 
92  //===--------------------------------------------------------------------===//
93  // Types
94 
95  Type readType(DialectBytecodeReader &reader) const override {
96  return ::readType(getContext(), reader);
97  }
98 
99  LogicalResult writeType(Type type,
100  DialectBytecodeWriter &writer) const override {
101  return ::writeType(type, writer);
102  }
103 
104  void writeVersion(DialectBytecodeWriter &writer) const final {
105  // TODO: Populate.
106  }
107 
108  std::unique_ptr<DialectVersion>
109  readVersion(DialectBytecodeReader &reader) const final {
110  // TODO: Populate
111  reader.emitError("Dialect does not support versioning");
112  return nullptr;
113  }
114 
115  LogicalResult upgradeFromVersion(Operation *topLevelOp,
116  const DialectVersion &version) const final {
117  return success();
118  }
119 };
120 
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // TOSA control flow support.
125 //===----------------------------------------------------------------------===//
126 
127 /// Returns the while loop body.
128 SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
129  return {&getBodyGraph()};
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // TOSA variable operator support.
134 //===----------------------------------------------------------------------===//
135 
136 SmallVector<int64_t> mlir::tosa::convertToMlirShape(ArrayRef<int64_t> shape) {
137  return to_vector(llvm::map_range(shape, [](int64_t dim) {
138  return dim == -1 ? ShapedType::kDynamic : dim;
139  }));
140 }
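// For illustration (values chosen arbitrarily): a TOSA shape of [2, -1, 4]
// maps to the MLIR shape {2, ShapedType::kDynamic, 4}, since TOSA encodes an
// unknown dimension as -1 while MLIR uses ShapedType::kDynamic.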
141 
142 // Returns the ranked tensor type of a variable op.
143 RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
144  Type elementType = variableOp.getType();
145  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
146  auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
147  return RankedTensorType::get(shape, elementType);
148 }
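// For illustration (attribute values assumed): a tosa.variable declared with
// var_shape = [1, 8] and type = f32 is reconstructed here as tensor<1x8xf32>;
// a -1 entry in var_shape becomes a dynamic dimension via convertToMlirShape.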
149 
150 //===----------------------------------------------------------------------===//
151 // Tosa dialect initialization.
152 //===----------------------------------------------------------------------===//
153 
154 void TosaDialect::initialize() {
155  addTypes<
156 #define GET_TYPEDEF_LIST
157 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
158  >();
159  addOperations<
160 #define GET_OP_LIST
161 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
162  >();
163  addAttributes<
164 #define GET_ATTRDEF_LIST
165 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
166  >();
167  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
168  declarePromisedInterfaces<
169  shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
170  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
171  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
172  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
173  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
174  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
175  GreaterEqualOp, MatMulOp>();
176 }
177 
178 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
179  Type type, Location loc) {
180  // Tosa dialect constants only support ElementsAttr unlike standard dialect
181  // constant which supports all attributes.
182  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
183  return tosa::ConstShapeOp::create(builder, loc, type,
184  llvm::cast<DenseIntElementsAttr>(value));
185  }
186  if (llvm::isa<ElementsAttr>(value))
187  return tosa::ConstOp::create(builder, loc, type,
188  llvm::cast<ElementsAttr>(value));
189  return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Parsers and printers
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
199  DenseElementsAttr &varShapeAttr,
200  TypeAttr &typeAttr) {
201  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
202  if (!shapedType.hasRank())
203  return parser.emitError(parser.getCurrentLocation())
204  << "expected ranked type";
205 
206  auto elementType = shapedType.getElementType();
207  typeAttr = TypeAttr::get(elementType);
208  ArrayRef<int64_t> shape = shapedType.getShape();
209  Builder builder(parser.getContext());
210  varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
211  return success();
212  }
213  return parser.emitError(parser.getCurrentLocation())
214  << "expected shaped type";
215 }
216 
217 } // namespace
218 
219 // parses the optional initial value or type for a tosa variable
220 // with initial value:
221 // tosa.variable @name = dense<0.0> : tensor<1x8xf32>
222 //
223 // without initial value:
224 // tosa.variable @name : tensor<1x8xf32>
225 ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
226  OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
227  Attribute &initialValueAttr) {
228  if (succeeded(parser.parseOptionalEqual())) {
229  if (failed(parser.parseAttribute(initialValueAttr))) {
230  return parser.emitError(parser.getCurrentLocation())
231  << "expected attribute";
232  }
233  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
234  return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
235  typeAttr);
236  }
237  return parser.emitError(parser.getCurrentLocation())
238  << "expected Typed attr";
239  }
240 
241  initialValueAttr = nullptr;
242  Type parsedType;
243  if (failed(parser.parseColonType(parsedType))) {
244  return parser.emitError(parser.getCurrentLocation())
245  << "expected type after colon";
246  }
247  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
248 }
249 
250 void mlir::tosa::printVariableOpTypeOrInitialValue(
251  OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
252  TypeAttr typeAttr, Attribute initialValueAttr) {
253  bool needsSpace = false;
254  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
255  auto shape =
256  convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
257  Type elementType = typeAttr.getValue();
258  RankedTensorType tensorType =
259  RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
260  auto tensorTypeAttr = TypeAttr::get(tensorType);
261  p << ": ";
262  p.printAttribute(tensorTypeAttr);
263  needsSpace = true; // subsequent attr value needs a space separator
264  }
265  if (initialValueAttr) {
266  if (needsSpace)
267  p << ' ';
268  p << "= ";
269  p.printAttribute(initialValueAttr);
270  }
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // Tosa utilities.
275 //===----------------------------------------------------------------------===//
276 
277 std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
278  if (lhs % rhs != 0)
279  return std::nullopt;
280  return lhs / rhs;
281 }
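// For illustration (values chosen arbitrarily): idivCheck(10, 2) yields 5,
// while idivCheck(10, 3) yields std::nullopt because 10 is not evenly
// divisible by 3.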
282 
283 Type mlir::tosa::getStorageElementTypeOrSelf(Type type) {
284  auto srcType = getElementTypeOrSelf(type);
285  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
286  srcType = quantType.getStorageType();
287  return srcType;
288 }
289 
290 Type mlir::tosa::getStorageElementTypeOrSelf(Value value) {
291  return getStorageElementTypeOrSelf(value.getType());
292 }
293 
294 static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
295  Value valZp, StringRef name) {
296  Type eType = getStorageElementTypeOrSelf(val.getType());
297  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
298 
299  bool bothInts =
300  mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
301  bool sameBitWidth =
302  (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
303 
304  if (!bothInts || !sameBitWidth) {
305  return op->emitOpError()
306  << "expected " << name << " and " << name
307  << "_zp to both be integer of the same bitwidth, but got " << eType
308  << " vs. " << eZpType;
309  }
310  return success();
311 }
312 
313 // Create a pad-const tensor with value `val` of the required data type.
314 Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
315  Value src, int32_t val) {
316  const auto srcType = getElementTypeOrSelf(src);
317  const auto srcElemType = getStorageElementTypeOrSelf(src);
318  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
319  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
320  const auto padConstAttr{
321  llvm::isa<FloatType>(srcElemType)
322  ? DenseElementsAttr::get(padConstEType,
323  builder.getFloatAttr(srcElemType, val))
324  : DenseElementsAttr::get(padConstEType,
325  builder.getIntegerAttr(srcElemType, val))};
326  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
327 }
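// Sketch of the result (illustrative types only): for an i8 input with
// val == 0 this builds a tosa.const of type tensor<1xi8> holding {0}; for an
// f32 input it holds {0.0}. Quantized inputs keep their quantized element
// type on the constant while the attribute uses the storage type.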
328 
329 //===----------------------------------------------------------------------===//
330 // TOSA Operator Verifiers.
331 //===----------------------------------------------------------------------===//
332 
333 template <typename T>
334 static LogicalResult verifyConvOp(T op) {
335  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
336  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());
337 
338  auto inputEType = inputType.getElementType();
339  auto weightEType = weightType.getElementType();
340  auto biasEType =
341  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
342  auto resultEType =
343  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
344  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
345  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
346 
347  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
348  inputEType = quantType.getStorageType();
349 
350  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
351  weightEType = quantType.getStorageType();
352 
353  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
354  biasEType = quantType.getStorageType();
355 
356  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
357  resultEType = quantType.getStorageType();
358 
359  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
360  // for now, only enforce bias element type == result element type for
361  // float types.
362  op.emitOpError(
363  "expect both bias and result to have same element type, got ")
364  << biasEType << " and " << resultEType;
365  return failure();
366  }
367 
368  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
369  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
370  if (inputEType != weightEType) {
371  op.emitOpError(
372  "expect both input and weight to have same element type, got ")
373  << inputEType << " and " << weightEType;
374  return failure();
375  }
376  }
377 
378  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
379  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
380 
381  // Either both must be float or both non-float.
382  if (inputIsFloat != weightIsFloat) {
383  op.emitOpError(
384  "expect both input and weight to be float or not together, got ")
385  << inputEType << " and " << weightEType;
386  return failure();
387  }
388 
389  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
390  if (inputEType != inputZpEType) {
391  return op.emitOpError("expect both input and its zero point are the same "
392  "element type, got ")
393  << inputEType << " and " << inputZpEType;
394  }
395 
396  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
397  if (weightEType != weightZpEType) {
398  return op.emitOpError("expect both weight and its zero point are the same "
399  "element type, got ")
400  << weightEType << " and " << weightZpEType;
401  }
402 
403  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
404  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
405  return failure();
406 
407  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
408  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
409  return failure();
410 
411  return success();
412 }
413 
414 LogicalResult tosa::ConstOp::verify() {
415 
416  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
417  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
418 
419  if (!attrType || !outputType) {
420  emitOpError("expected tensors for attr/result type");
421  return failure();
422  }
423 
424  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
425  outputType.getElementType())) {
426  if (result.getStorageType() == attrType.getElementType())
427  return success();
428  }
429 
430  if (attrType.getElementType() != outputType.getElementType()) {
431  emitOpError("expected same attr/result element types");
432  return failure();
433  }
434 
435  return success();
436 }
437 
438 template <typename T>
439 static LogicalResult verifyConvOpModes(T op) {
440  auto inputEType =
441  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
442 
443  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
444  inputEType = quantType.getStorageType();
445 
446  auto accType = op.getAccType();
447  if (inputEType.isInteger(8) && !accType.isInteger(32))
448  return op.emitOpError("accumulator type for i8 tensor is not i32");
449 
450  if (inputEType.isInteger(16) && !accType.isInteger(48))
451  return op.emitOpError("accumulator type for i16 tensor is not i48");
452 
453  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
454  return op.emitOpError("accumulator type for f8 tensor is not f16");
455 
456  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
457  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
458 
459  if (inputEType.isBF16() && !accType.isF32())
460  return op.emitOpError("accumulator type for bf16 tensor is not f32");
461 
462  if (inputEType.isF32() && !accType.isF32())
463  return op.emitOpError("accumulator type for f32 tensor is not f32");
464 
465  auto resultEType =
466  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
467 
468  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
469  resultEType = quantType.getStorageType();
470 
471  return success();
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // ERROR_IF functions.
476 // ERROR_IF is a predicate that must set an error if the condition holds.
477 //===----------------------------------------------------------------------===//
478 
479 template <typename T>
480 static LogicalResult verifyConvOpErrorIf(T op) {
481  llvm::ArrayRef<int64_t> padding = op.getPad();
482  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
483  return op.emitOpError("expect all padding values to be >= 0, got ")
484  << padding;
485 
486  llvm::ArrayRef<int64_t> strides = op.getStride();
487  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
488  return op.emitOpError("expect all stride values to be >= 1, got ")
489  << strides;
490 
491  llvm::ArrayRef<int64_t> dilations = op.getDilation();
492  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
493  return op.emitOpError("expect all dilation values to be >= 1, got ")
494  << dilations;
495 
496  const RankedTensorType outputType =
497  llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
498  if (!outputType)
499  // Skip following checks if output is not ranked
500  return success();
501 
502  const RankedTensorType inputType =
503  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
504  const RankedTensorType weightType =
505  llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
506 
507  if (inputType && weightType) {
508  const auto verifyOutputSize =
509  [&op](const int64_t inputSize, const int64_t kernelSize,
510  const int64_t outputSize, const int64_t padBefore,
511  const int64_t padAfter, const int64_t stride,
512  const int64_t dilation, const llvm::StringRef dimName,
513  const llvm::StringRef dimAxis,
514  const llvm::StringRef padBeforeName,
515  const llvm::StringRef padAfterName) -> LogicalResult {
516  if (inputSize == ShapedType::kDynamic ||
517  kernelSize == ShapedType::kDynamic)
518  return success();
519 
520  // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
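      // Worked example with illustrative values: I = 224, K = 3, pa = pb = 1,
      // d = 1, s = 1 gives a numerator of 224 - 1 + 1 + 1 - (3 - 1) * 1 = 223,
      // so O = 223 / 1 + 1 = 224. With s = 2 the same numerator is not evenly
      // divisible, so idivCheck returns std::nullopt and the op is rejected
      // rather than silently rounding.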
521 
522  const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
523  inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
524  stride);
525  if (!calculatedOutSizeMinusOne.has_value())
526  return op.emitOpError("expected input_")
527  << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
528  << padAfterName << " - (kernel_" << dimName
529  << " - 1) * dilation_" << dimAxis
530  << " to be wholly divisible by stride_" << dimAxis << ", got ("
531  << inputSize << " - 1 + " << padBefore << " + " << padAfter
532  << " - (" << kernelSize << " - 1) * " << dilation << ") / "
533  << stride;
534 
535  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
536  if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
537  return op.emitOpError("calculated output ")
538  << dimName << " did not match expected: "
539  << "calculated=" << calculatedOutSize
540  << ", expected=" << outputSize;
541 
542  return success();
543  };
544 
545  // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
546  if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
547  if (failed(verifyOutputSize(
548  inputType.getDimSize(1), weightType.getDimSize(1),
549  outputType.getDimSize(1), padding[0], padding[1], strides[0],
550  dilations[0], "height", "y", "top", "bottom")))
551  return failure();
552 
553  if (failed(verifyOutputSize(
554  inputType.getDimSize(2), weightType.getDimSize(2),
555  outputType.getDimSize(2), padding[2], padding[3], strides[1],
556  dilations[1], "width", "x", "left", "right")))
557  return failure();
558  }
559 
560  // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
561  if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
562  if (failed(verifyOutputSize(
563  inputType.getDimSize(1), weightType.getDimSize(0),
564  outputType.getDimSize(1), padding[0], padding[1], strides[0],
565  dilations[0], "height", "y", "top", "bottom")))
566  return failure();
567 
568  if (failed(verifyOutputSize(
569  inputType.getDimSize(2), weightType.getDimSize(1),
570  outputType.getDimSize(2), padding[2], padding[3], strides[1],
571  dilations[1], "width", "x", "left", "right")))
572  return failure();
573  }
574 
575  // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
576  if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
577  if (failed(verifyOutputSize(
578  inputType.getDimSize(1), weightType.getDimSize(1),
579  outputType.getDimSize(1), padding[0], padding[1], strides[0],
580  dilations[0], "depth", "d", "front", "back")))
581  return failure();
582 
583  if (failed(verifyOutputSize(
584  inputType.getDimSize(2), weightType.getDimSize(2),
585  outputType.getDimSize(2), padding[2], padding[3], strides[1],
586  dilations[1], "height", "y", "top", "bottom")))
587  return failure();
588 
589  if (failed(verifyOutputSize(
590  inputType.getDimSize(3), weightType.getDimSize(3),
591  outputType.getDimSize(3), padding[4], padding[5], strides[2],
592  dilations[2], "width", "x", "left", "right")))
593  return failure();
594  }
595  }
596 
597  const RankedTensorType biasType =
598  llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
599  if (!biasType)
600  // Skip following checks if bias is not ranked
601  return success();
602 
603  const int64_t biasChannels = biasType.getDimSize(0);
604  const int64_t outputChannels =
605  outputType.getDimSize(outputType.getRank() - 1);
606  if (biasChannels == ShapedType::kDynamic ||
607  outputChannels == ShapedType::kDynamic)
608  // Skip following checks if biasChannels or outputChannels is dynamic dim
609  return success();
610 
611  if (biasChannels != outputChannels && biasChannels != 1)
612  return op.emitOpError(
613  "bias channels expected to be equal to output channels (")
614  << outputChannels << ") or 1, got " << biasChannels;
615 
616  return success();
617 }
618 
619 // Verify that the given two types have the same element type and shape.
620 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
621  StringRef name1, Type type2,
622  StringRef name2) {
623  auto shapeType1 = dyn_cast<ShapedType>(type1);
624  auto shapeType2 = dyn_cast<ShapedType>(type2);
625  if (!shapeType1 || !shapeType2)
626  return failure();
627 
628  auto elemType1 = shapeType1.getElementType();
629  auto elemType2 = shapeType2.getElementType();
630  if (elemType1 != elemType2)
631  return op->emitOpError()
632  << "require same element type for " << name1 << " (" << elemType1
633  << ") and " << name2 << " (" << elemType2 << ")";
634 
635  if (failed(verifyCompatibleShape(type1, type2)))
636  return op->emitOpError()
637  << "require same shapes for " << name1 << " (" << type1 << ") and "
638  << name2 << " (" << type2 << ")";
639 
640  return success();
641 }
642 
643 // Verify that the given two tensor lists have the same length, element types, and shapes.
644 static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
645  StringRef name1,
646  ValueRange list2,
647  StringRef name2) {
648  if (list1.size() != list2.size())
649  return op->emitOpError()
650  << "require same number of values in " << name1 << " ("
651  << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
652 
653  for (auto [type1, type2] :
654  llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
655  if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
656  return failure();
657  }
658 
659  return success();
660 }
661 
662 static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
663  ShapeAdaptor shapeAdaptor(type);
664  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
665  return success();
666 
667  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
668 }
669 
670 // Returns the first declaration point prior to this operation or failure if
671 // not found.
672 static FailureOr<tosa::VariableOp> findVariableDecl(Operation *op,
673  StringRef symName) {
674  ModuleOp module = op->getParentOfType<ModuleOp>();
675  tosa::VariableOp varOp = nullptr;
676 
677  // TODO: Adopt the SymbolTable trait for Variable ops.
678  // Currently, the variable's definition point is searched via walk(),
679  // starting from the top-level ModuleOp and stopping at the point of use.
680  // Once the TOSA control flow and variable extensions are complete, this can
681  // leverage MLIR's SymbolTable functionality to look up the symbol and turn
682  // the search into a TOSA-specific traversal over the IR structure.
683  module.walk([&](Operation *tempOp) {
684  // Reach this op itself.
685  if (tempOp == op) {
686  return WalkResult::interrupt();
687  }
688 
689  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
690  if (symName == tosaOp.getName()) {
691  varOp = tosaOp;
692  return WalkResult::interrupt();
693  }
694  }
695 
696  return WalkResult::advance();
697  });
698 
699  if (varOp)
700  return varOp;
701 
702  return failure();
703 }
704 
705 template <typename T>
706 static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
707  StringRef symName = op.getName();
708  FailureOr<tosa::VariableOp> varOp = findVariableDecl(op, symName);
709  if (failed(varOp))
710  return op->emitOpError("'")
711  << symName << "' has not been declared by 'tosa.variable'";
712 
713  // Verify type and shape
714  auto variableType = getVariableType(varOp.value());
715  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
716  "the input tensor")
717  .failed())
718  return failure();
719 
720  return success();
721 }
722 
723 // verify that inType and outType have same element types
724 template <typename T>
725 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
726  auto inputType = llvm::dyn_cast<TensorType>(inType);
727  auto outputType = llvm::dyn_cast<TensorType>(outType);
728  if (!inputType) {
729  op.emitOpError("expect shaped tensor for input, got ") << inType;
730  return failure();
731  }
732  if (!outputType) {
733  op.emitOpError("expect shaped tensor for output, got ") << outType;
734  return failure();
735  }
736  auto inputElementType = inputType.getElementType();
737  auto outputElementType = outputType.getElementType();
738  auto inputQuantType =
739  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
740  auto outputQuantType =
741  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
742  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
743  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
744  inputElementType != outputElementType) {
745  // only check if both element types are int/index/float/UniformQuantized
746  // eg, not sure how to check quant::QuantizedType
747  // this happens in test_conv2d_q_grouped_convolution in
748  // tfl-to-tosa-pipeline.mlir
749  op.emitOpError("expect input and output to have same element type, got ")
750  << inputElementType << " and " << outputElementType;
751  return failure();
752  }
753  return success();
754 }
755 
756 LogicalResult tosa::ArgMaxOp::verify() {
757  const ShapedType resultType = llvm::cast<ShapedType>(getType());
758 
759  // Ensure the output element type is an integer type
760  if (const auto resultETy = resultType.getElementType();
761  !resultETy.isIntOrIndex())
762  return emitOpError("result tensor is not of integer type");
763 
764  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
765  if (!inputType.hasRank())
766  return success();
767 
768  // Ensure axis is within the tensor rank
769  const int64_t axis = getAxisAttr().getInt();
770  if (((axis < 0) || axis >= inputType.getRank()))
771  return emitOpError("specified axis is outside the rank of the tensor");
772 
773  if (!resultType.hasRank())
774  return success();
775 
776  const ArrayRef<int64_t> inputShape = inputType.getShape();
777  const ArrayRef<int64_t> outputShape = resultType.getShape();
778  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
779  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
780  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
781  return emitOpError("expected output shape '")
782  << expectedOutputShape << "', got '" << outputShape << "'";
783 
784  return success();
785 }
786 
787 template <typename T>
788 static LogicalResult verifyPoolingOp(T op) {
789  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
790  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
791  return op.emitOpError("expect all kernel values to be >= 1, got ")
792  << kernel;
793 
794  const llvm::ArrayRef<int64_t> strides = op.getStride();
795  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
796  return op.emitOpError("expect all stride values to be >= 1, got ")
797  << strides;
798 
799  const llvm::ArrayRef<int64_t> padding = op.getPad();
800  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
801  return op.emitOpError("expect all padding values to be >= 0, got ")
802  << padding;
803 
804  // Padding must be less than kernel size to avoid a divide-by-zero
805  const int64_t kernelX = kernel[1];
806  const int64_t padLeft = padding[2];
807  const int64_t padRight = padding[3];
808  if (padRight >= kernelX || padLeft >= kernelX)
809  return op.emitOpError("expected left/right padding to be less than the "
810  "width of the kernel, got pad_left=")
811  << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
812 
813  const int64_t kernelY = kernel[0];
814  const int64_t padTop = padding[0];
815  const int64_t padBottom = padding[1];
816  if (padTop >= kernelY || padBottom >= kernelY)
817  return op.emitOpError("expected top/bottom padding to be less than the "
818  "height of the kernel, got pad_top=")
819  << padTop << ", pad_bottom=" << padBottom
820  << ", kernel_y=" << kernelY;
821 
822  const auto inputType =
823  llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
824  const auto outputType =
825  llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
826  if (!inputType || !outputType)
827  return success();
828 
829  const auto verifyOutputSize =
830  [&op](const int64_t inputSize, const int64_t outputSize,
831  const int64_t kernelSize, const int64_t strideSize,
832  const int64_t padBefore, const int64_t padAfter,
833  const llvm::StringRef dimName, const llvm::StringRef dimAxis,
834  const llvm::StringRef padBeforeName,
835  const llvm::StringRef padAfterName) -> LogicalResult {
836  if (ShapedType::isDynamic(inputSize))
837  return success();
838 
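    // Worked example with illustrative values: an input of 32 with kernel 2,
    // stride 2 and no padding gives (32 + 0 + 0 - 2) / 2 = 15, so the expected
    // output size is 16. A combination that does not divide evenly (e.g. an
    // input of 33 with the same kernel and stride) yields std::nullopt and is
    // rejected below.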
839  const std::optional<int64_t> calculatedOutSizeMinusOne =
840  idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
841  if (!calculatedOutSizeMinusOne.has_value())
842  return op.emitOpError("expected input_")
843  << dimName << " + pad_" << padBeforeName << " + pad_"
844  << padAfterName << " - kernel_" << dimAxis
845  << " to be wholly divisible by stride_" << dimAxis << ", got ("
846  << inputSize << " + " << padBefore << " + " << padAfter << " - "
847  << kernelSize << ") / " << strideSize;
848 
849  const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
850  if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
851  return op.emitOpError("calculated output ")
852  << dimName << " did not match expected: "
853  << "calculated=" << calculatedOutSize
854  << ", expected=" << outputSize;
855 
856  return success();
857  };
858 
859  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
860  kernel[0], strides[0], padding[0], padding[1],
861  "height", "y", "top", "bottom")))
862  return failure();
863 
864  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
865  kernel[1], strides[1], padding[2], padding[3],
866  "width", "x", "left", "right")))
867  return failure();
868 
869  return success();
870 }
871 
872 LogicalResult tosa::AvgPool2dOp::verify() {
873  if (failed(verifyPoolingOp(*this)))
874  return failure();
875 
876  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
877  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
878  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
879  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
880 
881  auto accType = getAccType();
882  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
883  return emitOpError("accumulator type for integer tensor is not i32");
884 
885  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
886  return emitOpError("accumulator type for f16 tensor is not f16/f32");
887 
888  if (inputETy.isBF16() && !accType.isF32())
889  return emitOpError("accumulator type for bf16 tensor is not f32");
890 
891  if (inputETy.isF32() && !accType.isF32())
892  return emitOpError("accumulator type for f32 tensor is not f32");
893 
894  if (inputETy != inputZpETy)
895  return emitOpError("expect both input and its zero point are the same "
896  "element type, got ")
897  << inputETy << " and " << inputZpETy;
898 
899  if (resultETy != outputZpETy)
900  return emitOpError("expect both output and its zero point are the same "
901  "element type, got ")
902  << resultETy << " and " << outputZpETy;
903 
904  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
905  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
906  return failure();
907 
908  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
909  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
910  return failure();
911 
912  return success();
913 }
914 
915 LogicalResult tosa::ClampOp::verify() {
916  mlir::Type inputETy =
917  llvm::cast<ShapedType>(getInput().getType()).getElementType();
918  if (auto quantType =
919  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
920  inputETy = quantType.getStorageType();
921  }
922  mlir::Type outputETy =
923  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
924  if (auto quantType =
925  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
926  outputETy = quantType.getStorageType();
927  }
928  if (inputETy != outputETy)
929  return emitOpError("input/output element types are incompatible.");
930 
931  auto maxValAttr = getMaxValAttr();
932  auto minValAttr = getMinValAttr();
933 
934  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
935 
936  if (inputETy.isInteger(dataTypeBitWidth)) {
937  // if input datatype is integer, check that the min_val/max_val attributes
938  // are integer attributes, and that their type is the same as the input's
939  // datatype
940  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
941  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
942  if (!intMaxValAttr || !intMinValAttr ||
943  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
944  (intMaxValAttr.getType() != inputETy))
945  return emitOpError("min/max attributes types are incompatible with "
946  "input/output element types.");
947 
948  const bool isUnsigned = inputETy.isUnsignedInteger();
949  const bool isBoolean = inputETy.isInteger(1);
950  const APInt minVal = intMinValAttr.getValue();
951  const APInt maxVal = intMaxValAttr.getValue();
952  if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
953  return emitOpError("expected min_val <= max_val, got min_val=")
954  << minValAttr << ", max_val=" << maxValAttr;
955  } else {
956  // otherwise, input datatype is float, check that the min_val/max_val
957  // attributes share the same type and that their type is the same as the
958  // input's datatype
959  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
960  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
961  if (!floatMaxValAttr || !floatMinValAttr ||
962  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
963  (floatMaxValAttr.getType() != inputETy))
964  return emitOpError("min/max attributes types are incompatible with "
965  "input/output element types.");
966 
967  const APFloat minVal = floatMinValAttr.getValue();
968  const APFloat maxVal = floatMaxValAttr.getValue();
969  if (minVal.isNaN() || maxVal.isNaN())
970  return emitOpError("min/max attributes should not be 'NaN', got min_val=")
971  << minValAttr << ", max_val=" << maxValAttr;
972 
973  if (maxVal < minVal)
974  return emitOpError("expected min_val <= max_val, got min_val=")
975  << minValAttr << ", max_val=" << maxValAttr;
976  }
977 
978  return success();
979 }
980 
981 //===----------------------------------------------------------------------===//
982 // TOSA Operator Quantization Builders.
983 //===----------------------------------------------------------------------===//
984 
985 /// This builder is called on all convolution operators except TransposeConv,
986 /// which has specialized output shape semantics. The builder also defines the
987 /// bitwidth of the output given the bit width of the input & weight content.
988 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
989  Type outputType, Value input, Value weight,
990  Value bias, DenseI64ArrayAttr pad,
991  DenseI64ArrayAttr stride,
992  DenseI64ArrayAttr dilation,
993  TypeAttr accType) {
994  auto zps = createZPsAsConst(builder, input, weight);
995  result.addOperands({input, weight, bias, zps.first, zps.second});
996  result.addAttribute("pad", pad);
997  result.addAttribute("stride", stride);
998  result.addAttribute("dilation", dilation);
999  result.addAttribute("acc_type", accType);
1000  Type finalOutputType = outputType;
1001  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1002  if (quantAttr) {
1003  finalOutputType =
1004  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1005  }
1006  result.addTypes(finalOutputType);
1007 }
1008 
1009 /// Handles tosa.transpose_conv2d which has outpad and output shape
1010 /// attributes.
1011 static void
1012 buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1013  Type outputType, Value input, Value weight,
1014  Value bias, DenseI64ArrayAttr outpad,
1015  DenseI64ArrayAttr stride, TypeAttr accType) {
1016  auto zps = createZPsAsConst(builder, input, weight);
1017  result.addOperands({input, weight, bias, zps.first, zps.second});
1018  result.addAttribute("out_pad", outpad);
1019  result.addAttribute("stride", stride);
1020  result.addAttribute("acc_type", accType);
1021  Type finalOutputType = outputType;
1022  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1023  if (quantAttr) {
1024  finalOutputType =
1025  buildConvOpResultTypeInfo(builder, outputType, input, weight);
1026  }
1027  result.addTypes(finalOutputType);
1028 }
1029 
1030 /// The tosa.matmul op is also intended to be generated when a fully_connected
1031 /// op must be constructed with a non-constant weight. In this case, the
1032 /// fully_connected op must be expressed using matmul.
1033 /// TODO: Add link to the legalization document explaining this.
1034 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
1035  OperationState &result, Type outputType,
1036  Value a, Value b) {
1037  auto zps = createZPsAsConst(builder, a, b);
1038  result.addOperands({a, b, zps.first, zps.second});
1039 
1040  Type finalOutputType{outputType};
1041  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1042  auto eType = getStorageElementTypeOrSelf(a.getType());
1043  auto inputBits = eType.getIntOrFloatBitWidth();
1044 
1045  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1046  assert(outputShapedType && "Output must be a shaped type");
1047 
1048  IntegerType accElementType;
1049  if (inputBits == 16)
1050  accElementType = builder.getIntegerType(48);
1051  else
1052  accElementType = builder.getI32Type();
1053 
1054  finalOutputType = outputShapedType.clone(accElementType);
1055  }
1056  result.addTypes(finalOutputType);
1057 }
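// For illustration of the accumulator choice above: quantized i8 x i8 matmuls
// get an i32 result element type and i16 x i16 matmuls get i48, consistent
// with the accumulator rules enforced in verifyConvOpModes.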
1058 
1059 /// Both the tosa.avg_pool2d and unary ops use the same
1060 /// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1061 /// has additional parameters not part of the unary ops.
1062 static void
1063 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1064  Type outputType, Value input,
1065  DenseArrayAttr kernel, DenseArrayAttr stride,
1066  DenseArrayAttr pad, TypeAttr accType) {
1067  const Location loc{result.location};
1068  int64_t inputZp{0};
1069  int64_t outputZp{0};
1070 
1071  if (auto quantAttr =
1072  buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1073  inputZp = quantAttr.getInputZp();
1074  outputZp = quantAttr.getOutputZp();
1075  }
1076  const std::optional<Value> inputZpOp =
1077  createZeroPointTensor(builder, loc, input.getType(), inputZp);
1078  if (!inputZpOp) {
1079  (void)emitError(
1080  loc,
1081  "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1082  }
1083  const std::optional<Value> outputZpOp =
1084  createZeroPointTensor(builder, loc, outputType, outputZp);
1085  if (!outputZpOp) {
1086  (void)emitError(loc, "Failed to create output zero point tensor for "
1087  "quantized AVG_POOL2D op");
1088  }
1089 
1090  if (inputZpOp && outputZpOp) {
1091  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1092  } else {
1093  // Failed to create one or more zero points above: just add the input as
1094  // an operand. This will trigger an error when building the op because of
1095  // the missing zero points.
1096  result.addOperands({input});
1097  }
1098  result.addAttribute("kernel", kernel);
1099  result.addAttribute("stride", stride);
1100  result.addAttribute("pad", pad);
1101  result.addAttribute("acc_type", accType);
1102  result.types.push_back(outputType);
1103 }
1104 
1105 /// This builder is called on single-parameter negate operator
1106 /// to construct input and output zero points based on their
1107 /// types.
1108 static void buildNegateOpWithQuantInfo(OpBuilder &builder,
1109  OperationState &result, Type outputType,
1110  Value input) {
1111  const Location loc{result.location};
1112  int64_t input1Zp{0};
1113  int64_t outputZp{0};
1114  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1115  if (quantAttr) {
1116  input1Zp = quantAttr.getInputZp();
1117  outputZp = quantAttr.getOutputZp();
1118  }
1119  const std::optional<Value> input1ZpOp =
1120  createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1121  if (!input1ZpOp) {
1122  (void)emitError(
1123  loc, "Failed to create input1 zero point for quantized NEGATE op");
1124  }
1125 
1126  const std::optional<Value> outputZpOp =
1127  createZeroPointTensor(builder, loc, input.getType(), outputZp);
1128  if (!outputZpOp) {
1129  (void)emitError(
1130  loc, "Failed to create output zero point for quantized NEGATE op");
1131  }
1132 
1133  if (input1ZpOp && outputZpOp) {
1134  result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1135  } else {
1136  // Failed to create one or more zero points above: just add the input as
1137  // an operand. This will trigger an error when building the op because of
1138  // the missing zero points.
1139  result.addOperands({input});
1140  }
1141 
1142  result.types.push_back(outputType);
1143 }
1144 
1145 /// This builder is called on the TOSA pad operator, which needs to create
1146 /// its own OptionalAttr quantization_attr parameter to scale the padding
1147 /// values correctly. A missing pad_const is interpreted as zero padding.
1148 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
1149  Type outputType, Value input,
1150  Value paddings) {
1151  const Location loc{result.location};
1152  int32_t zp{0};
1153  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1154  if (quantAttr) {
1155  zp = static_cast<int32_t>(quantAttr.getInputZp());
1156  }
1157  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1158  result.addOperands({input, paddings, padConstOp});
1159  result.types.push_back(outputType);
1160 }
1161 
1162 static void buildVariableOp(OpBuilder &builder, OperationState &result,
1163  StringRef name, Type variableType,
1164  Attribute initialValue) {
1165  const Location loc{result.location};
1166  auto nameAttr = builder.getStringAttr(name);
1167 
1168  auto shapedType = dyn_cast<ShapedType>(variableType);
1169  if (!shapedType) {
1170  (void)emitError(loc, "variable type must be a shaped type");
1171  return;
1172  }
1173  if (!shapedType.hasRank()) {
1174  (void)emitError(loc, "variable type must be a ranked type");
1175  return;
1176  }
1177 
1178  auto elementType = shapedType.getElementType();
1179  auto elementTypeAttr = TypeAttr::get(elementType);
1180  ArrayRef<int64_t> shape = shapedType.getShape();
1181  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1182 
1183  result.addAttribute("name", nameAttr);
1184  result.addAttribute("var_shape", varShapeAttr);
1185  result.addAttribute("type", elementTypeAttr);
1186  result.addAttribute("initial_value", initialValue);
1187 }
1188 
1189 //===----------------------------------------------------------------------===//
1190 // TOSA Operator Return Type Inference.
1191 //===----------------------------------------------------------------------===//
1192 
1193 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1194  SmallVector<int64_t> &outShape) {
1195  int64_t outRank = 0;
1196  for (int i = 0, e = operands.size(); i != e; ++i) {
1197  auto shape = operands.getShape(i);
1198  if (!shape.hasRank()) {
1199  // TODO(jennik): Update function to have better case handling for
1200  // invalid operands and for ranked tensors.
1201  return failure();
1202  }
1203  outRank = std::max<int64_t>(outRank, shape.getRank());
1204  }
1205 
1206  outShape.resize(outRank, 1);
1207 
1208  for (int i = 0, e = operands.size(); i != e; ++i) {
1209  auto shape = operands.getShape(i);
1210  auto rankDiff = outShape.size() - shape.getRank();
1211 
1212  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1213  auto dim1 = outShape[i + rankDiff];
1214  auto dim2 = shape.getDimSize(i);
1215  auto resolvedDim = dim1;
1216 
1217  if (dim1 == 1) {
1218  resolvedDim = dim2;
1219  } else if (dim2 == 1) {
1220  resolvedDim = dim1;
1221  } else if (dim1 != dim2) {
1222  return failure();
1223  }
1224  outShape[i + rankDiff] = resolvedDim;
1225  }
1226  }
1227 
1228  return success();
1229 }
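// For illustration (shapes chosen arbitrarily): broadcasting operand shapes
// [1, 3] and [2, 1] resolves to [2, 3], while [2, 3] against [4, 3] fails
// because neither mismatched dimension is 1.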
1230 
1231 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1232  MLIRContext *context, ::std::optional<Location> location,
1233  ArgMaxOp::Adaptor adaptor,
1234  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1235  ShapeAdaptor inputShape(adaptor.getInput().getType());
1236  IntegerAttr axis = adaptor.getProperties().axis;
1237  int32_t axisVal = axis.getValue().getSExtValue();
1238 
1239  if (!inputShape.hasRank()) {
1240  inferredReturnShapes.push_back(ShapedTypeComponents());
1241  return success();
1242  }
1243 
1244  SmallVector<int64_t> outShape;
1245  outShape.reserve(inputShape.getRank() - 1);
1246  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1247  if (i == axisVal)
1248  continue;
1249  outShape.push_back(inputShape.getDimSize(i));
1250  }
1251 
1252  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1253  return success();
1254 }
1255 
1256 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1257  MLIRContext *context, ::std::optional<Location> location,
1258  RFFT2dOp::Adaptor adaptor,
1259  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1260  ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1261 
1262  if (!inputShape.hasRank())
1263  return failure();
1264 
1265  llvm::SmallVector<int64_t> outputShape;
1266  outputShape.resize(3, ShapedType::kDynamic);
1267  outputShape[0] = inputShape.getDimSize(0);
1268  outputShape[1] = inputShape.getDimSize(1);
1269  int64_t inWidth = inputShape.getDimSize(2);
1270 
1271  // Note that we can support this calculation symbolically
1272  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1273  if (inWidth != ShapedType::kDynamic)
1274  outputShape[2] = inWidth / 2 + 1;
1275 
1276  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1277  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1278 
1279  return success();
1280 }
1281 
1282 static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1283  const llvm::StringRef dimName) {
1284  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
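  // Illustration of the bit trick: 8 is 0b1000 and 7 is 0b0111, so 8 & 7 == 0
  // (a power of two); 6 is 0b0110 and 5 is 0b0101, so 6 & 5 == 0b0100 != 0.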
1285  if (!isPowerOfTwo)
1286  return op->emitOpError("expected ")
1287  << dimName << " to be a power of two, got " << dimSize;
1288 
1289  return success();
1290 }
1291 
1292 LogicalResult tosa::RFFT2dOp::verify() {
1293  const auto outputTypes = getResultTypes();
1294  if (failed(verifyCompatibleShapes(outputTypes)))
1295  return emitOpError("expected output shapes to match, got ") << outputTypes;
1296 
1297  const auto inputType =
1298  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1299  if (!inputType)
1300  return success();
1301 
1302  const int64_t height = inputType.getDimSize(1);
1303  if (ShapedType::isStatic(height) &&
1304  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1305  return failure();
1306 
1307  const int64_t width = inputType.getDimSize(2);
1308  if (ShapedType::isStatic(width) &&
1309  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1310  return failure();
1311 
1312  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
1313  if (!outputType)
1314  return success();
1315 
1316  // Batch and height input/output dimensions should match
1317  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
1318  outputType.getShape().drop_back())))
1319  return emitOpError("expected batch and height dimensions of input/output "
1320  "to match, got input=")
1321  << inputType << " output=" << outputType;
1322 
1323  // Output width dimension expected to be input_width / 2 + 1
1324  const int64_t outputWidth = outputType.getDimSize(2);
1325  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
1326  (outputWidth != (width / 2) + 1))
1327  return emitOpError(
1328  "expected output width to be equal to input_width / 2 + 1, got ")
1329  << outputWidth;
1330 
1331  return success();
1332 }
1333 
1334 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1335  MLIRContext *context, ::std::optional<Location> location,
1336  FFT2dOp::Adaptor adaptor,
1337  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1338  inferredReturnShapes.push_back(
1339  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1340  inferredReturnShapes.push_back(
1341  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1342  return success();
1343 }
1344 
1345 LogicalResult tosa::FFT2dOp::verify() {
1346  const auto inputRealType =
1347  llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1348  const auto inputImagType =
1349  llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1350  if (!inputRealType || !inputImagType)
1351  return success();
1352 
1353  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1354  return ShapedType::isDynamic(a) ? a : b;
1355  };
1356 
1357  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1358  inputImagType.getDimSize(1));
1359  if (ShapedType::isStatic(height) &&
1360  failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1361  return failure();
1362 
1363  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1364  inputImagType.getDimSize(2));
1365  if (ShapedType::isStatic(width) &&
1366  failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1367  return failure();
1368 
1369  return success();
1370 }
1371 
1372 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
1373  MLIRContext *context, ::std::optional<Location> location,
1374  ConcatOp::Adaptor adaptor,
1375  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1376  // Infer all dimension sizes by reducing based on inputs.
1377  const Properties &prop = adaptor.getProperties();
1378  int32_t axis = prop.axis.getValue().getSExtValue();
1379  llvm::SmallVector<int64_t> outputShape;
1380  bool hasRankedInput = false;
1381  for (auto operand : adaptor.getOperands()) {
1382  ShapeAdaptor operandShape(operand.getType());
1383  if (!operandShape.hasRank())
1384  continue;
1385 
1386  // Copy the Operand's rank.
1387  if (!hasRankedInput)
1388  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
1389 
1390  // Copy shapes until the dim is non-dynamic.
1391  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
1392  if (i == axis || operandShape.isDynamicDim(i))
1393  continue;
1394  if (outputShape[i] == ShapedType::kDynamic)
1395  outputShape[i] = operandShape.getDimSize(i);
1396  if (outputShape[i] != operandShape.getDimSize(i))
1397  return emitOptionalError(location,
1398  "Cannot concat tensors with different sizes"
1399  " on the non-axis dimension ",
1400  i);
1401  }
1402 
1403  hasRankedInput = true;
1404  }
1405 
1406  if (adaptor.getInput1().empty())
1407  return failure();
1408 
1409  Type inputType =
1410  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
1411  if (!hasRankedInput) {
1412  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1413  return success();
1414  }
1415 
1416  // Determine the dimension size along the concatenation axis.
1417  int64_t concatDimSize = 0;
1418  for (auto operand : adaptor.getOperands()) {
1419  ShapeAdaptor operandShape(operand.getType());
1420 
1421  // We need to know the length of the concatenation axis of all inputs to
1422  // determine the dimension size of the output shape.
1423  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
1424  concatDimSize = ShapedType::kDynamic;
1425  break;
1426  }
1427 
1428  concatDimSize += operandShape.getDimSize(axis);
1429  }
1430 
1431  outputShape[axis] = concatDimSize;
1432 
1433  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1434  return success();
1435 }
1436 
1437 LogicalResult tosa::ConcatOp::verify() {
1438  // check that each input has same element type as output
1439  auto outType = getOutput().getType();
1440  const Operation::operand_range inputList = getInput1();
1441 
1442  // Check there is at least one input
1443  if (inputList.empty())
1444  return emitOpError("expect at least one input");
1445 
1446  if (!llvm::all_of(inputList, [&](auto input) {
1447  return succeeded(verifySameElementTypes(
1448  *this, /* inType = */ input.getType(), outType));
1449  })) {
1450  return failure();
1451  }
1452 
1453  const int32_t axis = getAxis();
1454  ShapeAdaptor firstRankedInputShape = nullptr;
1455  for (const auto &input : inputList) {
1456  const Type inputType = input.getType();
1457  ShapeAdaptor currShape(inputType);
1458  if (currShape.hasRank()) {
1459  firstRankedInputShape = currShape;
1460  // Check axis is in expected range
1461  if (axis < 0 || axis >= firstRankedInputShape.getRank())
1462  return emitOpError("expect axis to be within range 0 <= axis < "
1463  "rank(input1[firstRankedTensorIdx]), got ")
1464  << axis;
1465  break;
1466  }
1467  }
1468 
1469  const auto allOperandsHasRank = [](const Value input) {
1470  return ShapeAdaptor(input.getType()).hasRank();
1471  };
1472  if (llvm::all_of(inputList, allOperandsHasRank)) {
1473  const int64_t firstInputRank = firstRankedInputShape.getRank();
1474 
1475  for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
1476  const ShapeAdaptor inputShape(input.getType());
1477  const int64_t inputRank = inputShape.getRank();
1478  const size_t operandNum = index + 1;
1479 
1480  // Check that each operand has the same rank
1481  if (inputRank != firstInputRank)
1482  return emitOpError(
1483  "expect all operands to have the same rank, but got ")
1484  << firstInputRank << " vs " << inputRank << " on operands 0 and "
1485  << operandNum;
1486 
1487  // Check non-axis dims match
1488  for (int i = 0; i < inputRank; i++) {
1489  const int64_t inputDim = inputShape.getDimSize(i);
1490  const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
1491  if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
1492  inputShape.isDynamicDim(i))
1493  continue;
1494  if (inputDim != firstInputDim)
1495  return emitOpError("expect all operand shapes to have the same sizes "
1496  "on non-axis dimensions, but got ")
1497  << inputDim << " vs " << firstInputDim << " at index " << i
1498  << " on operands 0 and " << operandNum;
1499  }
1500  }
1501 
1502  // ERROR_IF(axis_sum != shape[axis]);
1503  int64_t axisSum = 0;
1504  for (const auto &input : inputList) {
1505  const ShapeAdaptor inputShape(input.getType());
1506  if (inputShape.isDynamicDim(axis)) {
1507  // make axisSum negative to indicate invalid value
1508  axisSum = -1;
1509  break;
1510  }
1511  axisSum += inputShape.getDimSize(axis);
1512  }
1513  const ShapeAdaptor outputShape(outType);
1514  if (axisSum >= 0 && outputShape.hasRank() &&
1515  !outputShape.isDynamicDim(axis) &&
1516  axisSum != outputShape.getDimSize(axis))
1517  return emitOpError("requires sum of axis dimensions of input1 "
1518  "equal to output axis dimension, got ")
1519  << axisSum << " and " << outputShape.getDimSize(axis);
1520  }
1521 
1522  return success();
1523 }
1524 
1525 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1526  MLIRContext *context, ::std::optional<Location> location,
1527  ValueShapeRange operands, DictionaryAttr attributes,
1528  OpaqueProperties properties, RegionRange regions,
1529  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1530  auto elementType = IntegerType::get(context, /*width=*/1);
1531 
1532  llvm::SmallVector<int64_t> outShape;
1533  if (resolveBroadcastShape(operands, outShape).failed()) {
1534  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1535  return success();
1536  }
1537 
1538  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1539  return success();
1540 }
1541 
1542 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1543  if (l.size() != r.size() || l.size() != 1)
1544  return false;
1545  return succeeded(verifyCompatibleShape(l[0], r[0]));
1546 }
1547 
1548 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1549  MLIRContext *context, ::std::optional<Location> location,
1550  MatMulOp::Adaptor adaptor,
1551  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1552  ShapeAdaptor lhsShape(adaptor.getA().getType());
1553  ShapeAdaptor rhsShape(adaptor.getB().getType());
1554 
1555  // Start with all output dims dynamic.
1556  SmallVector<int64_t> outShape;
1557  outShape.resize(3, ShapedType::kDynamic);
1558 
1559  if (lhsShape.hasRank()) {
1560  outShape[0] = lhsShape.getDimSize(0);
1561  outShape[1] = lhsShape.getDimSize(1);
1562  }
1563 
1564  if (rhsShape.hasRank()) {
1565  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
1566  : outShape[0];
1567  outShape[2] = rhsShape.getDimSize(2);
1568  }
1569 
1570  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1571  return success();
1572 }
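
// Worked example (illustrative): the inference above follows the TOSA matmul
// layout A: [N, H, C] x B: [N, C, W] -> [N, H, W]. For instance, A of type
// tensor<1x4x8xf32> and B of type tensor<1x8x16xf32> infer shape [1, 4, 16];
// if A is unranked, N is taken from B's dim 0 and H stays dynamic.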
1573 
1574 LogicalResult MatMulOp::verify() {
1575  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
1576  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
1577 
1578  // Must be shaped tensor types
1579  if (!aType)
1580  return emitOpError("expect a shaped tensor for input a, got ")
1581  << getA().getType();
1582 
1583  if (!bType)
1584  return emitOpError("expect a shaped tensor for input b, got ")
1585  << getB().getType();
1586 
1587  auto aElementType = aType.getElementType();
1588  auto bElementType = bType.getElementType();
1589 
1590  auto aQuantizedEType =
1591  llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1592  auto bQuantizedEType =
1593  llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1594 
1595  if (aQuantizedEType || bQuantizedEType) {
1596  if (!aQuantizedEType || !bQuantizedEType) {
1597  return emitOpError("expect operands to be both quantized or both not "
1598  "quantized, got ")
1599  << aElementType << " and " << bElementType;
1600  }
1601  // both a and b have quantized element types
1602  auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1603  auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1604  if (aQuantWidth != bQuantWidth) {
1605  return emitOpError("expect quantized operands to have same widths, got ")
1606  << aQuantWidth << " and " << bQuantWidth;
1607  }
1608  } else {
1609  // non-quantized element types
1610  if (aElementType != bElementType) {
1611  return emitOpError("expect same element type for inputs a and b, got ")
1612  << aElementType << " and " << bElementType;
1613  }
1614  }
1615 
1616  // check a_zp and b_zp
1617  auto aEType = getStorageElementTypeOrSelf(aType);
1618  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
1619  if (aEType != aZpEType) {
1620  return emitOpError("expect input a and a_zp have the same "
1621  "element type, got ")
1622  << aEType << " and " << aZpEType;
1623  }
1624 
1625  auto bEType = getStorageElementTypeOrSelf(bType);
1626  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
1627  if (bEType != bZpEType) {
1628  return emitOpError("expect input b and b_zp have the same "
1629  "element type, got ")
1630  << bEType << " and " << bZpEType;
1631  }
1632 
1633  FailureOr<int64_t> maybeAZp = getAZeroPoint();
1634  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
1635  return failure();
1636 
1637  FailureOr<int64_t> maybeBZp = getBZeroPoint();
1638  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
1639  return failure();
1640 
1641  return success();
1642 }
1643 
1644 LogicalResult tosa::PadOp::inferReturnTypeComponents(
1645  MLIRContext *context, ::std::optional<Location> location,
1646  PadOp::Adaptor adaptor,
1647  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1648  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1649  auto paddingRank =
1650  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
1651  SmallVector<int64_t> outputShape;
1652 
1653  // If the input rank is unknown, we can infer the output rank using the
1654  // padding shape's rank divided by 2.
1655  if (!inputShape.hasRank()) {
1656  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1657  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1658  return success();
1659  }
1660 
1661  SmallVector<int64_t> paddingValues;
1662  // If the paddings value is not a constant, all dimensions must be dynamic.
1663  if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
1664  paddingValues)) {
1665  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1666  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1667  return success();
1668  }
1669 
1670  outputShape.reserve(inputShape.getRank());
1671  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1672  if (inputShape.isDynamicDim(i)) {
1673  outputShape.push_back(ShapedType::kDynamic);
1674  continue;
1675  }
1676  auto padFront = paddingValues[i * 2];
1677  auto padBack = paddingValues[i * 2 + 1];
1678  if (padFront < 0 || padBack < 0) {
1679  // If either padding for dim i is negative (dynamic), the output dim is unknown.
1680  outputShape.push_back(ShapedType::kDynamic);
1681  continue;
1682  }
1683 
1684  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1685  }
1686 
1687  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1688  return success();
1689 }
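
// Worked example (illustrative): each static output dim is
// input_dim + pad_front + pad_back, reading the padding values pairwise. For
// instance, input tensor<3x4xf32> with constant padding [1, 1, 2, 0] infers
// shape [5, 6] (3+1+1 = 5, 4+2+0 = 6); a negative padding entry makes that dim
// dynamic, and a non-constant padding operand makes every dim dynamic.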
1690 
1691 LogicalResult tosa::PadOp::verify() {
1692  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1693  /* outType = */ getOutput().getType())
1694  .failed()) {
1695  return failure();
1696  }
1697 
1698  if (auto padConst = getPadConst()) {
1699  if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1700  /* outType = */ getOutput().getType())
1701  .failed()) {
1702  return failure();
1703  }
1704  }
1705 
1706  RankedTensorType inputType =
1707  llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1708  RankedTensorType outputType =
1709  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
1710  if (!inputType || !outputType)
1711  return success();
1712 
1713  auto inputRank = inputType.getRank();
1714  auto outputRank = outputType.getRank();
1715  if (inputRank != outputRank)
1716  return emitOpError() << "expect same input and output tensor rank, but got "
1717  << "inputRank: " << inputRank
1718  << ", outputRank: " << outputRank;
1719 
1720  DenseIntElementsAttr paddingAttr;
1721  if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
1722  return failure();
1723  }
1724 
1725  auto paddingValues = paddingAttr.getValues<APInt>();
1726  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
1727  return emitOpError() << "padding tensor must have " << inputRank
1728  << " * 2 = " << inputRank * 2 << " elements, but got "
1729  << paddingValues.size();
1730 
1731  auto inputShape = inputType.getShape();
1732  auto outputShape = outputType.getShape();
1733 
1734  for (int64_t i = 0; i < inputRank; ++i) {
1735  int64_t padStart = paddingValues[i * 2].getSExtValue();
1736  int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();
1737 
1738  if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
1739  return emitOpError()
1740  << "invalid padding values at dimension " << i
1741  << ": values must be non-negative or -1 for dynamic padding, got ["
1742  << padStart << ", " << padEnd << "]";
1743  }
1744 
1745  // Skip shape verification for dynamic input/output
1746  if (inputShape[i] == ShapedType::kDynamic ||
1747  outputShape[i] == ShapedType::kDynamic)
1748  continue;
1749 
1750  if (outputShape[i] != inputShape[i] + padStart + padEnd) {
1751  return emitOpError() << "mismatch in output shape at dimension " << i
1752  << ": expected " << inputShape[i] << " + "
1753  << padStart << " + " << padEnd << " = "
1754  << (inputShape[i] + padStart + padEnd)
1755  << ", but got " << outputShape[i];
1756  }
1757  }
1758 
1759  return success();
1760 }
1761 
1762 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1763  MLIRContext *context, ::std::optional<Location> location,
1764  SliceOp::Adaptor adaptor,
1765  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1766 
1767  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1768  SmallVector<int64_t> start;
1769  SmallVector<int64_t> size;
1770 
1771  if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
1772  !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
1773  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1774  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1775  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1776  return success();
1777  }
1778 
1779  // if size[i] is -1, all remaining elements in dimension i are included
1780  // in the slice, similar to TF.
1781  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1782  // initialize outputShape to all unknown
1783  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1784  if (inputShape.hasRank()) {
1785  for (size_t i = 0; i < size.size(); i++) {
1786  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1787  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1788  start[i] < inputShape.getDimSize(i))) {
1789  // size[i] is not 0 and not < -1, and start[i] is in valid range
1790  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1791  // input shape has unknown dim[i] - only valid if size[i] > 0
1792  if (size[i] > 0) {
1793  outputShape[i] = size[i];
1794  }
1795  } else {
1796  // input shape has known dim[i]
1797  if (size[i] == -1) {
1798  outputShape[i] = inputShape.getDimSize(i) - start[i];
1799  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1800  // start[i] + size[i] is within bound of input shape's dim[i]
1801  outputShape[i] = size[i];
1802  }
1803  }
1804  }
1805  }
1806  } else {
1807  outputShape = convertToMlirShape(size);
1808  }
1809  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1810  return success();
1811 }
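
// Worked example (illustrative): a size of -1 selects everything from start to
// the end of that dim. Slicing an input of shape [10, ?] with start [2, 0] and
// size [-1, 4] infers dim 0 as 10 - 2 = 8 and dim 1 as 4, i.e. shape [8, 4];
// a zero or out-of-bounds size leaves the corresponding dim dynamic.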
1812 
1813 LogicalResult tosa::SliceOp::verify() {
1814  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1815  /* outType = */ getOutput().getType())
1816  .failed())
1817  return failure();
1818 
1819  const ShapeAdaptor inputShape(getInput1().getType());
1820  if (inputShape.hasRank()) {
1821  const auto inputRank = inputShape.getRank();
1822  const ShapeAdaptor outputShape(getOutput().getType());
1823  if (outputShape.hasRank() && inputRank != outputShape.getRank())
1824  return emitOpError(
1825  "expect input1 and output to have the same ranks, got ")
1826  << inputRank << " and " << outputShape.getRank();
1827 
1828  const auto startShapeRank =
1829  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1830  if (inputRank != startShapeRank)
1831  return emitOpError("length of start is not equal to rank of input shape");
1832 
1833  const auto sizeShapeRank =
1834  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1835  if (inputRank != sizeShapeRank)
1836  return emitOpError("length of size is not equal to rank of input shape");
1837  }
1838 
1839  return success();
1840 }
1841 
1842 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1843  MLIRContext *context, ::std::optional<Location> location,
1844  ValueShapeRange operands, DictionaryAttr attributes,
1845  OpaqueProperties properties, RegionRange regions,
1846  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1847  // The mul op's output shape depends only on input1 and input2, not on shift.
1848  ValueShapeRange twoInputs = operands.drop_back();
1849  llvm::SmallVector<int64_t> outShape;
1850  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1851  inferredReturnShapes.push_back(ShapedTypeComponents());
1852  } else {
1853  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1854  }
1855  return success();
1856 }
1857 
1858 LogicalResult tosa::MulOp::verify() {
1859  const Value output = getOutput();
1860  auto resElemType = getElementTypeOrSelf(output);
1861 
1862  // Verify if the element type among operands and result match tosa
1863  // specification.
1864  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1865  IntegerType lhsIntType =
1866  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1867  IntegerType rhsIntType =
1868  dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1869  if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
1870  return emitOpError("requires the same element type for all operands");
1871 
1872  // Though the spec requires the result element type to be i32, a more
1873  // relaxed constraint is allowed at the dialect level to ease interoperation
1874  // with other dialects.
1875  if (lhsIntType.getWidth() > resIntType.getWidth())
1876  return emitOpError("invalid data type size for operands or result");
1877 
1878  } else {
1879  // For other supported types, the spec requires the same element type for
1880  // all operands (excluding the `shift` operand) and results.
1881  for (int i = 0; i < 2; ++i) {
1882  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1883  return emitOpError(
1884  "requires the same element type for all operands and results");
1885  }
1886 
1887  // verify shift has value 0 for non-integer types
1888  ElementsAttr shift_elem;
1889  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1890  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1891  if (shift != 0) {
1892  return emitOpError() << "require shift to be 0 for float type";
1893  }
1894  }
1895  }
1896 
1897  // Verify the op has the same rank for all main operands and result types,
1898  // if known. The extra `shift` operand is excluded, which is the only
1899  // difference from the built-in `SameOperandsAndResultRank` trait.
1900  TypeRange operandTypes = getOperandTypes();
1901  ShapedType aType = cast<ShapedType>(operandTypes[0]);
1902  ShapedType bType = cast<ShapedType>(operandTypes[1]);
1903 
1904  const bool aHasRank = aType.hasRank();
1905  const bool bHasRank = bType.hasRank();
1906  if (aHasRank && bHasRank) {
1907  const int64_t aRank = aType.getRank();
1908  const int64_t bRank = bType.getRank();
1909  if (aRank != bRank)
1910  return emitOpError("a and b operands don't have matching ranks, got ")
1911  << aRank << " and " << bRank;
1912 
1913  // check for broadcast compatible shapes
1914  SmallVector<int64_t> resultShape;
1915  if (!mlir::OpTrait::util::getBroadcastedShape(
1916  aType.getShape(), bType.getShape(), resultShape))
1917  return emitOpError("a and b operands don't have broadcast-compatible "
1918  "shapes, got ")
1919  << aType << " and " << bType;
1920  }
1921 
1922  ShapedType resultType = cast<ShapedType>(output.getType());
1923  if (!resultType.hasRank())
1924  return success();
1925 
1926  const int64_t resultRank = resultType.getRank();
1927  if (aHasRank && resultRank != aType.getRank())
1928  return emitOpError("result type has different rank than a, got ")
1929  << resultRank << " vs " << aType.getRank();
1930  if (bHasRank && resultRank != bType.getRank())
1931  return emitOpError("result type has different rank than b, got ")
1932  << resultRank << " vs " << bType.getRank();
1933 
1934  return success();
1935 }
1936 
1937 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1938  MLIRContext *context, ::std::optional<Location> location,
1939  TableOp::Adaptor adaptor,
1940  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1941  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1942 
1943  if (!inputShape.hasRank()) {
1944  inferredReturnShapes.push_back(ShapedTypeComponents());
1945  return success();
1946  }
1947 
1948  inferredReturnShapes.resize(1);
1949  inputShape.getDims(inferredReturnShapes[0]);
1950  return success();
1951 }
1952 
1953 LogicalResult tosa::TableOp::verify() {
1954  TensorType inputType = getInput1().getType();
1955  TensorType outputType = getOutput().getType();
1956 
1957  if (inputType.hasRank() && outputType.hasRank() &&
1958  inputType.getRank() != outputType.getRank())
1959  return emitOpError()
1960  << "expected input tensor rank to equal result tensor rank";
1961 
1962  auto inputDims = inputType.getShape();
1963  auto outputDims = outputType.getShape();
1964  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1965  int64_t dim = it.index();
1966  auto [inputDim, outputDim] = it.value();
1967  if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
1968  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1969  << " doesn't match dim(input, " << dim
1970  << ") = " << inputDim;
1971  }
1972  }
1973  return success();
1974 }
1975 
1976 LogicalResult
1977 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1978  // Multiples must be constants.
1979  DenseIntElementsAttr multiplesAttr;
1980  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1981  return failure();
1982  multiples = llvm::to_vector(
1983  llvm::map_range(multiplesAttr.getValues<APInt>(),
1984  [](const APInt &val) { return val.getSExtValue(); }));
1985  return success();
1986 }
1987 
1988 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1989  MLIRContext *context, ::std::optional<Location> location,
1990  TileOp::Adaptor adaptor,
1991  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1992  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1993  SmallVector<int64_t> multiples;
1994  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
1995  multiples)) {
1996  auto rank =
1997  cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
1998  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1999  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2000  return success();
2001  } else {
2002  multiples = convertToMlirShape(multiples);
2003  }
2004 
2005  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2006  SmallVector<int64_t> outputShape;
2007  if (!inputShape.hasRank()) {
2008  outputShape.resize(multiples.size(), ShapedType::kDynamic);
2009  inferredReturnShapes.push_back(
2010  ShapedTypeComponents(outputShape, inputType));
2011  return success();
2012  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2013  return failure();
2014 
2015  // Any non dynamic dimension can be multiplied to a known size.
2016  outputShape.reserve(multiples.size());
2017  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2018  if (multiples[i] == ShapedType::kDynamic) {
2019  outputShape.push_back(ShapedType::kDynamic);
2020  } else {
2021  int64_t dim = inputShape.getDimSize(i);
2022  if (dim != ShapedType::kDynamic)
2023  dim *= multiples[i];
2024  outputShape.push_back(dim);
2025  }
2026  }
2027 
2028  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2029  return success();
2030 }
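
// Worked example (illustrative): each known output dim is input_dim * multiple.
// For instance, input tensor<2x3xf32> with constant multiples [3, 2] infers
// shape [6, 6]; a dynamic input dim, or a multiple mapped to
// ShapedType::kDynamic above, leaves that output dim dynamic.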
2031 
2032 LogicalResult tosa::TileOp::verify() {
2033  if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2034  /* outType = */ getOutput().getType())
2035  .failed()) {
2036  return failure();
2037  }
2038  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2039  ShapedType outputType = llvm::cast<ShapedType>(getType());
2040 
2041  shapeType multiplesType =
2042  llvm::cast<tosa::shapeType>(getMultiples().getType());
2043 
2044  auto multiplesRank = multiplesType.getRank();
2045 
2046  if (inputType.hasRank()) {
2047  if (inputType.getRank() != multiplesRank)
2048  return emitOpError("expect 'multiples' to have rank ")
2049  << inputType.getRank() << " but got " << multiplesRank << ".";
2050  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
2051  return emitOpError("expect same input and output tensor rank.");
2052  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2053  return emitOpError("expect 'multiples' array to have length ")
2054  << outputType.getRank() << " but got " << multiplesRank << ".";
2055 
2056  SmallVector<int64_t> multiples;
2057  if (getConstantMultiples(multiples).succeeded() &&
2058  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2059  return emitOpError(
2060  "expect element of 'multiples' to be positive integer or -1.");
2061 
2062  return success();
2063 }
2064 
2065 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2066  if (l.size() != r.size() || l.size() != 1)
2067  return false;
2068  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2069 }
2070 
2071 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2072  MLIRContext *context, ::std::optional<Location> location,
2073  ReshapeOp::Adaptor adaptor,
2074  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2075  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2076  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2077  llvm::SmallVector<int64_t> newShapeValue;
2078  if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2079  newShapeValue)) {
2080  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2081  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2082  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2083  return success();
2084  } else {
2085  newShapeValue = convertToMlirShape(newShapeValue);
2086  }
2087 
2088  // We cannot infer from the total number of elements so we must take the
2089  // shape attribute as exact.
2090  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2091  inferredReturnShapes.push_back(
2092  ShapedTypeComponents(newShapeValue, inputType));
2093  return success();
2094  }
2095 
2096  // Determine the number of elements covered by the slice of all static
2097  // dimensions. This allows us to infer the length of the remaining dynamic
2098  // dimension.
2099  int64_t numElements = inputShape.getNumElements();
2100  int64_t staticMul = 1;
2101  for (auto val : newShapeValue) {
2102  if (ShapedType::isStatic(val)) {
2103  staticMul *= val;
2104  }
2105  }
2106 
2107  // Determine the length of the dynamic dimension.
2108  for (auto &val : newShapeValue) {
2109  if (ShapedType::isDynamic(val))
2110  val = numElements / staticMul;
2111  }
2112 
2113  inferredReturnShapes.push_back(
2114  ShapedTypeComponents(newShapeValue, inputType));
2115  return success();
2116 }
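
// Worked example (illustrative): with a static input, the single dynamic entry
// is solved from the element count. Reshaping tensor<2x12xf32> (24 elements)
// with shape [4, -1] gives staticMul = 4 and a dynamic dim of 24 / 4 = 6,
// i.e. an inferred shape of [4, 6].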
2117 
2118 llvm::LogicalResult tosa::ReshapeOp::verify() {
2119  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2120  /* outType = */ getOutput().getType())
2121  .failed()) {
2122  return failure();
2123  }
2124  TensorType inputType = getInput1().getType();
2125 
2126  SmallVector<int64_t> shapeValues;
2127  if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2128  // skip following checks if shape is not constant
2129  return mlir::success();
2130  }
2131 
2132  int missingDims = llvm::count(shapeValues, -1);
2133  if (missingDims > 1)
2134  return emitOpError() << "expected at most one target dimension to be -1";
2135 
2136  const auto outputType = dyn_cast<RankedTensorType>(getType());
2137  if (!outputType)
2138  return success();
2139 
2140  if ((int64_t)shapeValues.size() != outputType.getRank())
2141  return emitOpError() << "new shape does not match result rank";
2142 
2143  for (auto [newShapeDim, outputShapeDim] :
2144  zip(shapeValues, outputType.getShape())) {
2145  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
2146  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2147  return emitOpError() << "new shape is inconsistent with result shape";
2148 
2149  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
2150  return emitOpError() << "new shape has invalid tensor dimension size "
2151  << newShapeDim;
2152  }
2153 
2154  if (inputType.hasStaticShape()) {
2155  int64_t inputElementsNum = inputType.getNumElements();
2156  if (outputType.hasStaticShape()) {
2157  int64_t outputElementsNum = outputType.getNumElements();
2158  if (inputElementsNum != outputElementsNum) {
2159  return emitOpError() << "cannot reshape " << inputElementsNum
2160  << " elements into " << outputElementsNum;
2161  }
2162  }
2163 
2164  int64_t newShapeElementsNum = std::accumulate(
2165  shapeValues.begin(), shapeValues.end(), 1LL,
2166  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
2167  bool isStaticNewShape =
2168  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2169  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2170  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2171  return emitOpError() << "cannot reshape " << inputElementsNum
2172  << " elements into " << newShapeElementsNum;
2173  }
2174  }
2175 
2176  return mlir::success();
2177 }
2178 
2179 // Return failure if val is not a constant.
2180 // Set zp to -1 if val is a non-zero float or is neither integer nor float;
2181 // otherwise set zp to val's constant value.
2182 static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
2183  ElementsAttr zpAttr;
2184  if (!matchPattern(val, m_Constant(&zpAttr))) {
2185  return failure();
2186  }
2187 
2188  Type zpElemType = zpAttr.getElementType();
2189 
2190  if (llvm::isa<FloatType>(zpElemType)) {
2191  if (zpAttr.getValues<APFloat>()[0].isZero()) {
2192  return 0;
2193  }
2194  // return non-zero value to trigger error check
2195  return -1;
2196  }
2197 
2198  if (llvm::isa<IntegerType>(zpElemType)) {
2199  if (signExtend)
2200  return zpAttr.getValues<APInt>()[0].getSExtValue();
2201  else
2202  return zpAttr.getValues<APInt>()[0].getZExtValue();
2203  }
2204 
2205  // return non-zero value to trigger error check
2206  return -1;
2207 }
2208 
2209 template <typename T>
2210 static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
2211  const std::string &operand) {
2212  Type zpElemType = getElementTypeOrSelf(val);
2213 
2214  if (!zpElemType.isInteger(8) && zp != 0) {
2215  // convert operand to lower case for error message
2216  std::string lower = operand;
2217  std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
2218  return op.emitOpError()
2219  << lower << " zero point must be zero for non-int8 integer types";
2220  }
2221 
2222  return success();
2223 }
2224 
2225 static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
2226  const int64_t &zp,
2227  const std::string &operand) {
2228  bool isInputZp = (operand == "Input");
2229 
2230  bool tensorUnsigned =
2231  isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
2232  StringRef tensorName = isInputZp ? "input" : "output";
2233 
2234  Type zpElemType = getElementTypeOrSelf(zpVal);
2235 
2236  if (zp != 0) {
2237  if (!zpElemType.isInteger(8) &&
2238  !(zpElemType.isInteger(16) && tensorUnsigned)) {
2239  return op.emitOpError()
2240  << "expect " << tensorName << "_zp of 0, got " << zp;
2241  }
2242  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
2243  return op.emitOpError() << "expect " << tensorName
2244  << "_zp of 0 or 32768 for unsigned int16 "
2245  << tensorName << ", got " << zp;
2246  }
2247  }
2248 
2249  return success();
2250 }
2251 
2252 #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
2253  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
2254  return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
2255  } \
2256  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
2257  return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
2258  }
2259 
2260 ZERO_POINT_HELPER(Conv2DOp, Input, true)
2261 ZERO_POINT_HELPER(Conv2DOp, Weight, true)
2262 ZERO_POINT_HELPER(Conv3DOp, Input, true)
2263 ZERO_POINT_HELPER(Conv3DOp, Weight, true)
2264 ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
2265 ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
2266 ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
2267 ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
2268 ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
2269 ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
2270 ZERO_POINT_HELPER(MatMulOp, A, true)
2271 ZERO_POINT_HELPER(MatMulOp, B, true)
2272 ZERO_POINT_HELPER(NegateOp, Input1, true)
2273 ZERO_POINT_HELPER(NegateOp, Output, true)
2274 ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
2275 ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
2276 #undef ZERO_POINT_HELPER
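
// For reference (illustrative expansion), ZERO_POINT_HELPER(MatMulOp, A, true)
// expands roughly to:
//
//   FailureOr<int64_t> tosa::MatMulOp::getAZeroPoint() {
//     return getZeroPoint(getAZp(), true);
//   }
//   LogicalResult tosa::MatMulOp::verifyAZeroPoint(int64_t zp) {
//     return verifyZeroPoint(*this, getAZp(), zp, "A");
//   }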
2277 
2278 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
2279  MLIRContext *context, ::std::optional<Location> location,
2280  TransposeOp::Adaptor adaptor,
2281  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2282  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2283 
2284  // If the input rank is unknown, the output rank (and hence the shape) is
2285  // unknown.
2286  if (!inputShape.hasRank()) {
2287  inferredReturnShapes.push_back(ShapedTypeComponents());
2288  return success();
2289  }
2290 
2291  const auto inputRank = inputShape.getRank();
2292 
2293  // This would imply the number of permutations does not match the rank of
2294  // the input which is illegal.
2295  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
2296  return failure();
2297  }
2298 
2299  SmallVector<int64_t> outputShape;
2300  // Rank-0 means no permutations matter.
2301  if (inputRank == 0) {
2302  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2303  return success();
2304  }
2305 
2306  // Check whether the input dimensions are all the same.
2307  bool allTheSame = true;
2308  for (int i = 1, s = inputRank; i < s; i++) {
2309  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
2310  allTheSame = false;
2311  break;
2312  }
2313  }
2314 
2315  // If all of the input dimensions are the same we don't care about the
2316  // permutation.
2317  if (allTheSame) {
2318  outputShape.resize(inputRank, inputShape.getDimSize(0));
2319  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2320  return success();
2321  }
2322 
2323  outputShape.resize(inputRank, ShapedType::kDynamic);
2324 
2325  // Constant permutation values must be within the input rank.
2326  if (llvm::any_of(adaptor.getPerms(),
2327  [inputRank](const auto i) { return i >= inputRank; }))
2328  return failure();
2329 
2330  outputShape.reserve(inputRank);
2331  for (int i = 0, s = inputRank; i < s; i++) {
2332  outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
2333  }
2334 
2335  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2336  return success();
2337 }
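
// Worked example (illustrative): output dim i takes input dim perms[i]. For
// instance, input tensor<1x2x3xf32> with perms = [2, 0, 1] infers shape
// [3, 1, 2]; when every input dim is equal the permutation is irrelevant and
// the shape is copied directly.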
2338 
2339 LogicalResult tosa::TransposeOp::verify() {
2340  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2341  /* outType = */ getOutput().getType())
2342  .failed()) {
2343  return failure();
2344  }
2345 
2346  const ShapeAdaptor inputShape(getInput1().getType());
2347  const ShapeAdaptor outputShape(getOutput().getType());
2348 
2349  const llvm::ArrayRef<int32_t> constantPerms = getPerms();
2350 
2351  if (inputShape.hasRank() &&
2352  constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
2353  return emitOpError() << "expected perms attribute to have size "
2354  << inputShape.getRank()
2355  << " (input rank) but got size "
2356  << constantPerms.size();
2357 
2358  if (inputShape.hasRank() && outputShape.hasRank() &&
2359  inputShape.getRank() != outputShape.getRank())
2360  return emitOpError()
2361  << "expected input tensor rank to equal result tensor rank";
2362 
2363  if (outputShape.hasRank() &&
2364  constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
2365  return emitOpError() << "expected perms attribute to have size "
2366  << outputShape.getRank()
2367  << " (output rank) but got size "
2368  << constantPerms.size();
2369 
2370  if (!llvm::all_of(constantPerms,
2371  [&constantPerms](int32_t s) {
2372  return s >= 0 &&
2373  static_cast<size_t>(s) < constantPerms.size();
2374  }) ||
2375  !isPermutationVector(llvm::to_vector(llvm::map_range(
2376  constantPerms, [](int32_t v) -> int64_t { return v; }))))
2377  return emitOpError() << "expected valid permutation indices";
2378 
2379  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
2380  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
2381  inputShape.getNumElements() != outputShape.getNumElements())
2382  return emitOpError() << "expected input1 and output to have same numbers "
2383  "of elements, got "
2384  << inputShape.getNumElements() << " and "
2385  << outputShape.getNumElements();
2386 
2387  // Verify that the shapes of the input and output tensors are properly
2388  // permuted.
2389  if (inputShape.hasRank() && outputShape.hasRank()) {
2390  for (auto i = 0; i < outputShape.getRank(); i++) {
2391  if (inputShape.isDynamicDim(constantPerms[i]) ||
2392  outputShape.isDynamicDim(i))
2393  continue;
2394 
2395  if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
2396  return emitOpError()
2397  << "expected output tensor dim " << i << " to match "
2398  << "input dim " << constantPerms[i] << " with value of "
2399  << inputShape.getDimSize(constantPerms[i]);
2400  }
2401  }
2402 
2403  return success();
2404 }
2405 
2406 LogicalResult TransposeOp::reifyResultShapes(
2407  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2408 
2409  const llvm::ArrayRef<int32_t> transposePerms = getPerms();
2410 
2411  Value input = getInput1();
2412  auto inputType = cast<TensorType>(input.getType());
2413 
2414  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
2415  for (auto dim : transposePerms) {
2416  int32_t dimInInput = transposePerms[dim];
2417  if (inputType.isDynamicDim(dimInInput))
2418  returnedDims[dim] =
2419  tensor::DimOp::create(builder, getLoc(), input, dimInInput)
2420  .getResult();
2421  else
2422  returnedDims[dim] =
2423  builder.getIndexAttr(inputType.getDimSize(dimInInput));
2424  }
2425 
2426  reifiedReturnShapes.emplace_back(std::move(returnedDims));
2427  return success();
2428 }
2429 
2430 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
2431  MLIRContext *context, ::std::optional<Location> location,
2432  GatherOp::Adaptor adaptor,
2433  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2434  llvm::SmallVector<int64_t> outputShape;
2435  outputShape.resize(3, ShapedType::kDynamic);
2436 
2437  ShapeAdaptor valuesShape(adaptor.getValues().getType());
2438  if (valuesShape.hasRank()) {
2439  outputShape[0] = valuesShape.getDimSize(0);
2440  outputShape[2] = valuesShape.getDimSize(2);
2441  }
2442 
2443  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2444  if (indicesShape.hasRank()) {
2445  if (outputShape[0] == ShapedType::kDynamic)
2446  outputShape[0] = indicesShape.getDimSize(0);
2447  if (outputShape[1] == ShapedType::kDynamic)
2448  outputShape[1] = indicesShape.getDimSize(1);
2449  }
2450 
2451  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2452  return success();
2453 }
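
// Worked example (illustrative): gather takes values: [N, K, C] and
// indices: [N, W] and produces output: [N, W, C]. For instance, values of
// shape [2, 5, 4] and indices of shape [2, 3] infer output shape [2, 3, 4];
// any still-unknown dim is left as ShapedType::kDynamic.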
2454 
2455 LogicalResult tosa::GatherOp::verify() {
2456  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2457  /* outType = */ getOutput().getType())
2458  .failed()) {
2459  return failure();
2460  }
2461 
2462  const ShapeAdaptor valuesShape(getValues().getType());
2463  const ShapeAdaptor indicesShape(getIndices().getType());
2464  const ShapeAdaptor outputShape(getOutput().getType());
2465 
2466  int64_t N = ShapedType::kDynamic;
2467  int64_t W = ShapedType::kDynamic;
2468  int64_t C = ShapedType::kDynamic;
2469 
2470  if (valuesShape.hasRank()) {
2471  N = valuesShape.getDimSize(0);
2472  C = valuesShape.getDimSize(2);
2473  }
2474  if (indicesShape.hasRank()) {
2475  const int64_t indicesN = indicesShape.getDimSize(0);
2476  W = indicesShape.getDimSize(1);
2477  if (N == ShapedType::kDynamic)
2478  N = indicesN;
2479  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2480  return emitOpError() << "requires indices dimension 0 to have size " << N
2481  << ", got " << indicesN;
2482  }
2483  if (outputShape.hasRank()) {
2484  const int64_t outputN = outputShape.getDimSize(0);
2485  const int64_t outputW = outputShape.getDimSize(1);
2486  const int64_t outputC = outputShape.getDimSize(2);
2487  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2488  N != outputN)
2489  return emitOpError() << "requires output dimension 0 to have size " << N
2490  << ", got " << outputN;
2491 
2492  if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2493  W != outputW)
2494  return emitOpError() << "requires output dimension 1 to have size " << W
2495  << ", got " << outputW;
2496  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2497  C != outputC)
2498  return emitOpError() << "requires output dimension 2 to have size " << C
2499  << ", got " << outputC;
2500  }
2501  return success();
2502 }
2503 
2504 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2505  MLIRContext *context, ::std::optional<Location> location,
2506  ResizeOp::Adaptor adaptor,
2507  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2508  llvm::SmallVector<int64_t, 4> outputShape;
2509  outputShape.resize(4, ShapedType::kDynamic);
2510 
2511  ShapeAdaptor inputShape(adaptor.getInput().getType());
2512  if (!inputShape.hasRank())
2513  return failure();
2514 
2515  outputShape[0] = inputShape.getDimSize(0);
2516  outputShape[3] = inputShape.getDimSize(3);
2517  int64_t inputHeight = inputShape.getDimSize(1);
2518  int64_t inputWidth = inputShape.getDimSize(2);
2519 
2520  if ((inputHeight == ShapedType::kDynamic) ||
2521  (inputWidth == ShapedType::kDynamic))
2522  return failure();
2523 
2524  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
2525  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
2526  scaleInt) ||
2527  !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
2528  offsetInt) ||
2529  !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
2530  borderInt)) {
2531  return failure();
2532  }
2533 
2534  // Compute the output shape based on attributes: scale, offset, and border.
2535  const int64_t outputHeight =
2536  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2537  scaleInt[1]) +
2538  1;
2539 
2540  const int64_t outputWidth =
2541  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2542  scaleInt[3]) +
2543  1;
2544 
2545  if (outputHeight < 0 || outputWidth < 0) {
2546  return emitOptionalError(
2547  location,
2548  "calculated output height and width must be non-negative, "
2549  "got height = ",
2550  outputHeight, ", width = ", outputWidth);
2551  }
2552 
2553  outputShape[1] = outputHeight;
2554  outputShape[2] = outputWidth;
2555  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2556  return success();
2557 }
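
// Worked example (illustrative): with scale = [scale_y_n, scale_y_d,
// scale_x_n, scale_x_d], the formulas above give, for input height 8,
// scale_y = 2/1, offset_y = 0, border_y = 0:
//   output_height = ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15.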
2558 
2559 LogicalResult tosa::ResizeOp::verify() {
2560  const Value input = getInput();
2561  const Value output = getOutput();
2562  const RankedTensorType inputType =
2563  llvm::dyn_cast<RankedTensorType>(input.getType());
2564  const RankedTensorType outputType =
2565  llvm::dyn_cast<RankedTensorType>(output.getType());
2566 
2567  SmallVector<int64_t> scaleValues;
2568  SmallVector<int64_t> offsetValues;
2569  SmallVector<int64_t> borderValues;
2570  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
2571  !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
2572  !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
2573  // Skip following checks if shape is not constant
2574  return success();
2575  }
2576 
2577  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
2578  return emitOpError("expect all scale values to be > 0, got ")
2579  << scaleValues;
2580 
2581  const int64_t scaleYN = scaleValues[0];
2582  const int64_t scaleYD = scaleValues[1];
2583  const int64_t scaleXN = scaleValues[2];
2584  const int64_t scaleXD = scaleValues[3];
2585 
2586  const int64_t offsetY = offsetValues[0];
2587  const int64_t offsetX = offsetValues[1];
2588 
2589  const int64_t borderY = borderValues[0];
2590  const int64_t borderX = borderValues[1];
2591 
2592  if (!inputType)
2593  return success();
2594  if (!outputType)
2595  return success();
2596 
2597  const int64_t oh = outputType.getDimSize(1);
2598  const int64_t ow = outputType.getDimSize(2);
2599  const int64_t ih = inputType.getDimSize(1);
2600  const int64_t iw = inputType.getDimSize(2);
2601 
2602  // Skip the output height check when the input height could be broadcast
2603  // (ih == 1), since Linalg, a consumer of TOSA, expects broadcasting support
2604  // in resize to be available. Taking the cautious approach for now; we can
2605  // consider removing support for broadcasting later.
2606  if (ih != ShapedType::kDynamic && ih != 1) {
2607  const std::optional<int64_t> calculatedOutHeightMinusOne =
2608  idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2609  if (!calculatedOutHeightMinusOne.has_value())
2610  return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2611  "border_y ")
2612  << "to be wholly divisible by scale_y_d, got ((" << ih
2613  << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2614  << ") / " << scaleYD;
2615  const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2616  if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2617  return emitOpError("calculated output height did not match expected: ")
2618  << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2619  }
2620 
2621  // Skip the output width check when the input width could be broadcast
2622  // (iw == 1), since Linalg, a consumer of TOSA, expects broadcasting support
2623  // in resize to be available. Taking the cautious approach for now; we can
2624  // consider removing support for broadcasting later.
2625  if (iw != ShapedType::kDynamic && iw != 1) {
2626  const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2627  const std::optional<int64_t> calculatedOutWidthMinusOne =
2628  idivCheck(scaledInWidth, scaleXD);
2629  if (!calculatedOutWidthMinusOne.has_value())
2630  return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2631  "border_x ")
2632  << "to be wholly divisible by scale_x_d, got ((" << iw
2633  << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2634  << ") / " << scaleXD;
2635  const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2636  if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2637  return emitOpError("calculated output width did not match expected: ")
2638  << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2639  }
2640 
2641  return success();
2642 }
2643 
2644 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2645  MLIRContext *context, ::std::optional<Location> location,
2646  ScatterOp::Adaptor adaptor,
2647  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2648  llvm::SmallVector<int64_t> outputShape;
2649  outputShape.resize(3, ShapedType::kDynamic);
2650 
2651  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2652  if (valuesInShape.hasRank()) {
2653  outputShape[0] = valuesInShape.getDimSize(0);
2654  outputShape[1] = valuesInShape.getDimSize(1);
2655  outputShape[2] = valuesInShape.getDimSize(2);
2656  }
2657 
2658  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2659  if (indicesShape.hasRank()) {
2660  if (outputShape[0] == ShapedType::kDynamic)
2661  outputShape[0] = indicesShape.getDimSize(0);
2662  }
2663 
2664  ShapeAdaptor inputShape(adaptor.getInput().getType());
2665  if (inputShape.hasRank()) {
2666  if (outputShape[0] == ShapedType::kDynamic)
2667  outputShape[0] = inputShape.getDimSize(0);
2668  if (outputShape[2] == ShapedType::kDynamic)
2669  outputShape[2] = inputShape.getDimSize(2);
2670  }
2671 
2672  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2673  return success();
2674 }
2675 
2676 LogicalResult tosa::ScatterOp::verify() {
2677  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
2678  /* outType = */ getValuesOut().getType())
2679  .failed() ||
2680  verifySameElementTypes(*this, /* inType = */ getInput().getType(),
2681  /* outType = */ getValuesOut().getType())
2682  .failed()) {
2683  return failure();
2684  }
2685 
2686  const ShapeAdaptor valuesInShape(getValuesIn().getType());
2687  const ShapeAdaptor indicesShape(getIndices().getType());
2688  const ShapeAdaptor inputShape(getInput().getType());
2689  const ShapeAdaptor outputShape(getValuesOut().getType());
2690 
2691  int64_t N = ShapedType::kDynamic;
2692  int64_t K = ShapedType::kDynamic;
2693  int64_t W = ShapedType::kDynamic;
2694  int64_t C = ShapedType::kDynamic;
2695  if (valuesInShape.hasRank()) {
2696  N = valuesInShape.getDimSize(0);
2697  K = valuesInShape.getDimSize(1);
2698  C = valuesInShape.getDimSize(2);
2699  }
2700  if (indicesShape.hasRank()) {
2701  const int64_t indicesN = indicesShape.getDimSize(0);
2702  W = indicesShape.getDimSize(1);
2703  if (N == ShapedType::kDynamic)
2704  N = indicesN;
2705  else if (indicesN != ShapedType::kDynamic && N != indicesN)
2706  return emitOpError() << "requires indices dimension 0 to have size " << N
2707  << ", got " << indicesN;
2708  }
2709  if (inputShape.hasRank()) {
2710  const int64_t inputN = inputShape.getDimSize(0);
2711  const int64_t inputW = inputShape.getDimSize(1);
2712  const int64_t inputC = inputShape.getDimSize(2);
2713  if (N == ShapedType::kDynamic)
2714  N = inputN;
2715  else if (inputN != ShapedType::kDynamic && N != inputN)
2716  return emitOpError() << "requires input dimension 0 to have size " << N
2717  << ", got " << inputN;
2718  if (W == ShapedType::kDynamic)
2719  W = inputW;
2720  else if (inputW != ShapedType::kDynamic && W != inputW)
2721  return emitOpError() << "requires input dimension 1 to have size " << W
2722  << ", got " << inputW;
2723 
2724  if (C == ShapedType::kDynamic)
2725  C = inputC;
2726  else if (inputC != ShapedType::kDynamic && C != inputC)
2727  return emitOpError() << "requires input dimension 2 to have size " << C
2728  << ", got " << inputC;
2729  }
2730  if (outputShape.hasRank()) {
2731  const int64_t outputN = outputShape.getDimSize(0);
2732  const int64_t outputK = outputShape.getDimSize(1);
2733  const int64_t outputC = outputShape.getDimSize(2);
2734  if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2735  N != outputN)
2736  return emitOpError() << "requires values_out dimension 0 to have size "
2737  << N << ", got " << outputN;
2738  if (K == ShapedType::kDynamic)
2739  K = outputK;
2740  else if (outputK != ShapedType::kDynamic && K != outputK)
2741  return emitOpError() << "requires values_out dimension 1 to have size "
2742  << K << ", got " << outputK;
2743  if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2744  C != outputC)
2745  return emitOpError() << "requires values_out dimension 2 to have size "
2746  << C << ", got " << outputC;
2747  }
2748  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2749  return emitOpError() << "requires dimensions K >= W, got K=" << K
2750  << " and W=" << W;
2751 
2752  return success();
2753 }
2754 
2755 static LogicalResult ReduceInferReturnTypes(
2756  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
2757  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2758  int64_t axisVal = axis.getValue().getSExtValue();
2759  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
2760  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
2761  return success();
2762  }
2763 
2764  SmallVector<int64_t> outputShape;
2765  operandShape.getDims(outputShape);
2766  outputShape[axisVal] = 1;
2767  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2768  return success();
2769 }
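
// Worked example (illustrative): a reduction keeps the input rank and sets the
// reduced dim to 1. For instance, reducing tensor<2x3x4xf32> over axis = 1
// infers shape [2, 1, 4]; an unranked input or an out-of-range axis falls back
// to a result carrying only the element type.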
2770 
2771 #define COMPATIBLE_RETURN_TYPES(OP) \
2772  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
2773  if (l.size() != r.size() || l.size() != 1) \
2774  return false; \
2775  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
2776  return false; \
2777  return succeeded(verifyCompatibleShape(l[0], r[0])); \
2778  }
2779 
2780 #define REDUCE_SHAPE_INFER(OP) \
2781  LogicalResult OP::inferReturnTypeComponents( \
2782  MLIRContext *context, ::std::optional<Location> location, \
2783  OP::Adaptor adaptor, \
2784  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2785  Type inputType = \
2786  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
2787  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
2788  const Properties &prop = adaptor.getProperties(); \
2789  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
2790  inferredReturnShapes); \
2791  } \
2792  COMPATIBLE_RETURN_TYPES(OP)
2793 
2794 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
2795 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
2796 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
2797 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
2798 REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
2799 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
2800 #undef REDUCE_SHAPE_INFER
2801 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
2802 #undef COMPATIBLE_RETURN_TYPES
2803 
2804 template <typename T>
2805 static LogicalResult verifyReduceOp(T op) {
2806  // All TOSA reduce Ops have input, output and axis.
2807  TensorType inputType = op.getInput().getType();
2808  TensorType outputType = op.getOutput().getType();
2809  int32_t reduceAxis = op.getAxis();
2810 
2811  if (reduceAxis < 0) {
2812  op.emitOpError("reduce axis must not be negative");
2813  return failure();
2814  }
2815  if (inputType.hasRank()) {
2816  int64_t inputRank = inputType.getRank();
2817  // We allow for a special case where the input/output shape has rank 0 and
2818  // axis is also 0.
2819  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
2820  op.emitOpError("expect input tensor rank (")
2821  << inputRank << ") to be larger than reduce axis (" << reduceAxis
2822  << ")";
2823  return failure();
2824  }
2825  }
2826  if (outputType.hasRank()) {
2827  int64_t outputRank = outputType.getRank();
2828  if (inputType.hasRank() && outputRank != inputType.getRank()) {
2829  op.emitOpError(
2830  "expect output tensor rank to be equal to input tensor rank");
2831  return failure();
2832  }
2833  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
2834  op.emitOpError("expect output tensor rank (")
2835  << outputRank << ") to be larger than reduce axis (" << reduceAxis
2836  << ")";
2837  return failure();
2838  }
2839  // We can only verify the reduced dimension size to be 1 if this is not
2840  // the special case of output rank == 0.
2841  if (outputRank != 0) {
2842  auto outputShape = outputType.getShape();
2843  if (!outputType.isDynamicDim(reduceAxis) &&
2844  outputShape[reduceAxis] != 1) {
2845  op.emitOpError("expect reduced dimension size to be 1, got ")
2846  << outputShape[reduceAxis];
2847  return failure();
2848  }
2849  }
2850  }
2851  return success();
2852 }
2853 
2854 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
2855 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
2856 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
2857 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
2858 LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
2859 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
2860 
2861 static LogicalResult NAryInferReturnTypes(
2862  const ValueShapeRange &operands,
2863  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2864  llvm::SmallVector<int64_t> outShape;
2865  if (resolveBroadcastShape(operands, outShape).failed()) {
2866  inferredReturnShapes.push_back(ShapedTypeComponents());
2867  } else {
2868  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2869  }
2870  return success();
2871 }
2872 
2873 #define NARY_SHAPE_INFER(OP) \
2874  LogicalResult OP::inferReturnTypeComponents( \
2875  MLIRContext *context, ::std::optional<Location> location, \
2876  ValueShapeRange operands, DictionaryAttr attributes, \
2877  OpaqueProperties properties, RegionRange regions, \
2878  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
2879  return NAryInferReturnTypes(operands, inferredReturnShapes); \
2880  }
2881 
2882 NARY_SHAPE_INFER(tosa::AbsOp)
2883 NARY_SHAPE_INFER(tosa::AddOp)
2884 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
2885 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
2886 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
2887 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
2888 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
2889 NARY_SHAPE_INFER(tosa::CastOp)
2890 NARY_SHAPE_INFER(tosa::CeilOp)
2891 NARY_SHAPE_INFER(tosa::ClampOp)
2892 NARY_SHAPE_INFER(tosa::ClzOp)
2893 NARY_SHAPE_INFER(tosa::CosOp)
2894 NARY_SHAPE_INFER(tosa::ExpOp)
2895 NARY_SHAPE_INFER(tosa::FloorOp)
2896 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
2897 NARY_SHAPE_INFER(tosa::GreaterOp)
2898 NARY_SHAPE_INFER(tosa::IdentityOp)
2899 NARY_SHAPE_INFER(tosa::IntDivOp)
2900 NARY_SHAPE_INFER(tosa::LogOp)
2901 NARY_SHAPE_INFER(tosa::LogicalAndOp)
2902 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
2903 NARY_SHAPE_INFER(tosa::LogicalNotOp)
2904 NARY_SHAPE_INFER(tosa::LogicalOrOp)
2905 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
2906 NARY_SHAPE_INFER(tosa::LogicalXorOp)
2907 NARY_SHAPE_INFER(tosa::MaximumOp)
2908 NARY_SHAPE_INFER(tosa::MinimumOp)
2909 NARY_SHAPE_INFER(tosa::PowOp)
2910 NARY_SHAPE_INFER(tosa::ReciprocalOp)
2911 NARY_SHAPE_INFER(tosa::ReverseOp)
2912 NARY_SHAPE_INFER(tosa::RsqrtOp)
2913 NARY_SHAPE_INFER(tosa::SinOp)
2914 NARY_SHAPE_INFER(tosa::SelectOp)
2915 NARY_SHAPE_INFER(tosa::SubOp)
2916 NARY_SHAPE_INFER(tosa::TanhOp)
2917 NARY_SHAPE_INFER(tosa::ErfOp)
2918 NARY_SHAPE_INFER(tosa::SigmoidOp)
2919 #undef NARY_SHAPE_INFER
2920 
2921 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
2922  MLIRContext *context, ::std::optional<Location> location,
2923  NegateOp::Adaptor adaptor,
2924  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2925  ShapeAdaptor inputShape(adaptor.getInput1().getType());
2926  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
2927  return success();
2928 }
2929 
2930 LogicalResult tosa::NegateOp::verify() {
2931  // Verify same element type
2932  const Type input1Type = getInput1().getType();
2933  const Type outputType = getOutput().getType();
2934  if (verifySameElementTypes(*this, input1Type, outputType).failed())
2935  return failure();
2936 
2937  // Verify same shape
2938  const SmallVector<Type, 2> types = {input1Type, outputType};
2939  if (failed(verifyCompatibleShapes(types)))
2940  return emitOpError() << "requires the same shape for input1 and output";
2941 
2942  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
2943  const Type input1ZpEType =
2944  getStorageElementTypeOrSelf(getInput1Zp().getType());
2945  if (input1EType != input1ZpEType) {
2946  return emitOpError("expect both input1 and its zero point are the same "
2947  "element type, got ")
2948  << input1EType << " and " << input1ZpEType;
2949  }
2950  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
2951  const Type outputZpEType =
2952  getStorageElementTypeOrSelf(getOutputZp().getType());
2953  if (outputEType != outputZpEType) {
2954  return emitOpError("expect both output and its zero point are the same "
2955  "element type, got ")
2956  << outputEType << " and " << outputZpEType;
2957  }
2958 
2959  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
2960  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
2961  return failure();
2962 
2963  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
2964  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
2965  return failure();
2966 
2967  return success();
2968 }
2969 
2970 static LogicalResult poolingInferReturnTypes(
2971  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
2972  ArrayRef<int64_t> pad,
2973  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2974  llvm::SmallVector<int64_t> outputShape;
2975  outputShape.resize(4, ShapedType::kDynamic);
2976 
2977  // If the input type is unranked, we only know the output rank (4), not the dims.
2978  if (!inputShape) {
2979  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2980  return success();
2981  }
2982 
2983  // The batch and channel dimensions pass through the pooling layer unchanged.
2984  outputShape[0] = inputShape.getDimSize(0);
2985  outputShape[3] = inputShape.getDimSize(3);
2986 
2987  int64_t height = inputShape.getDimSize(1);
2988  int64_t width = inputShape.getDimSize(2);
2989 
2990  if (ShapedType::isStatic(height)) {
2991  int64_t padded = height + pad[0] + pad[1] - kernel[0];
2992  outputShape[1] = padded / stride[0] + 1;
2993  }
2994 
2995  if (ShapedType::isStatic(width)) {
2996  int64_t padded = width + pad[2] + pad[3] - kernel[1];
2997  outputShape[2] = padded / stride[1] + 1;
2998  }
2999 
3000  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3001  return success();
3002 }
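// Worked example: for an input of shape [N, 112, 112, C] with kernel = [3, 3],
// stride = [2, 2] and pad = [1, 1, 1, 1], the padded extent per spatial dim is
// 112 + 1 + 1 - 3 = 111, so each output spatial dim is 111 / 2 + 1 = 56 and the
// inferred shape is [N, 56, 56, C].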
3003 
3004 LogicalResult Conv2DOp::inferReturnTypeComponents(
3005  MLIRContext *context, ::std::optional<Location> location,
3006  Conv2DOp::Adaptor adaptor,
3007  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3008  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3009 
3010  int64_t inputWidth = ShapedType::kDynamic;
3011  int64_t inputHeight = ShapedType::kDynamic;
3012  int64_t weightWidth = ShapedType::kDynamic;
3013  int64_t weightHeight = ShapedType::kDynamic;
3014 
3015  // Input shape describes input width/height and batch.
3016 
3017  ShapeAdaptor inputShape(adaptor.getInput().getType());
3018  if (inputShape.hasRank()) {
3019  outputShape[0] = inputShape.getDimSize(0);
3020  inputHeight = inputShape.getDimSize(1);
3021  inputWidth = inputShape.getDimSize(2);
3022  }
3023 
3024  // The weight shape describes the filter width/height and the output channels.
3025  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3026  if (weightShape.hasRank()) {
3027  outputShape[3] = weightShape.getDimSize(0);
3028  weightHeight = weightShape.getDimSize(1);
3029  weightWidth = weightShape.getDimSize(2);
3030  }
3031 
3032  // Bias shape can describe the output channels.
3033  ShapeAdaptor biasShape(adaptor.getBias().getType());
3034  if (biasShape.hasRank()) {
3035  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3036  ? biasShape.getDimSize(0)
3037  : outputShape[3];
3038  }
3039 
3040  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3041  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3042  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3043 
3044  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3045  int64_t inputSize = inputHeight + padding[0] + padding[1];
3046  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3047  int64_t unstridedResult = inputSize - filterSize + 1;
3048  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3049  }
3050 
3051  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3052  int64_t inputSize = inputWidth + padding[2] + padding[3];
3053  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3054  int64_t unstridedResult = inputSize - filterSize + 1;
3055  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3056  }
3057 
3058  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3059  return success();
3060 }
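// Worked example: with IH = 224, pad = [1, 1, ...], KH = 3, dilation = [1, ...]
// and stride = [2, ...]: inputSize = 224 + 1 + 1 = 226, filterSize = 3,
// unstridedResult = 226 - 3 + 1 = 224, and OH = (224 - 1) / 2 + 1 = 112. The
// width dimension follows the same formula with the remaining pad/stride
// entries.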
3061 
3062 LogicalResult Conv2DOp::verify() {
3063  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3064  verifyConvOpErrorIf(*this).failed())
3065  return failure();
3066  return success();
3067 }
3068 
3069 LogicalResult Conv3DOp::inferReturnTypeComponents(
3070  MLIRContext *context, ::std::optional<Location> location,
3071  Conv3DOp::Adaptor adaptor,
3072  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3073  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
3074 
3075  int64_t inputWidth = ShapedType::kDynamic;
3076  int64_t inputHeight = ShapedType::kDynamic;
3077  int64_t inputDepth = ShapedType::kDynamic;
3078 
3079  int64_t weightWidth = ShapedType::kDynamic;
3080  int64_t weightHeight = ShapedType::kDynamic;
3081  int64_t weightDepth = ShapedType::kDynamic;
3082 
3083  // Input shape describes input width/height and batch.
3084  ShapeAdaptor inputShape(adaptor.getInput().getType());
3085  if (inputShape.hasRank()) {
3086  outputShape[0] = inputShape.getDimSize(0);
3087  inputDepth = inputShape.getDimSize(1);
3088  inputHeight = inputShape.getDimSize(2);
3089  inputWidth = inputShape.getDimSize(3);
3090  }
3091 
3092  // The weight shape describes the filter width/height and the output channels.
3093  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3094  if (weightShape.hasRank()) {
3095  outputShape[4] = weightShape.getDimSize(0);
3096  weightDepth = weightShape.getDimSize(1);
3097  weightHeight = weightShape.getDimSize(2);
3098  weightWidth = weightShape.getDimSize(3);
3099  }
3100 
3101  // Bias shape can describe the output channels.
3102  ShapeAdaptor biasShape(adaptor.getBias().getType());
3103  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3104  outputShape[4] = biasShape.getDimSize(0);
3105  }
3106 
3107  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3108  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3109  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
3110 
3111  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3112  int32_t inputSize = inputDepth + pad[0] + pad[1];
3113  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3114  int32_t unstridedResult = inputSize - filterSize + 1;
3115  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3116  }
3117 
3118  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3119  int32_t inputSize = inputHeight + pad[2] + pad[3];
3120  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3121  int32_t unstridedResult = inputSize - filterSize + 1;
3122  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3123  }
3124 
3125  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3126  int32_t inputSize = inputWidth + pad[4] + pad[5];
3127  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3128  int32_t unstridedResult = inputSize - filterSize + 1;
3129  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3130  }
3131 
3132  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3133  return success();
3134 }
3135 
3136 LogicalResult Conv3DOp::verify() {
3137  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3138  verifyConvOpErrorIf(*this).failed())
3139  return failure();
3140  return success();
3141 }
3142 
3143 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3144  MLIRContext *context, ::std::optional<Location> location,
3145  AvgPool2dOp::Adaptor adaptor,
3146  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3147  ShapeAdaptor inputShape(adaptor.getInput().getType());
3148  const Properties &prop = adaptor.getProperties();
3149  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3150  inferredReturnShapes);
3151 }
3152 
3153 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3154  MLIRContext *context, ::std::optional<Location> location,
3155  MaxPool2dOp::Adaptor adaptor,
3156  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3157  ShapeAdaptor inputShape(adaptor.getInput().getType());
3158  const Properties &prop = adaptor.getProperties();
3159  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
3160  inferredReturnShapes);
3161 }
3162 
3163 LogicalResult MaxPool2dOp::verify() {
3164  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
3165  /* outType = */ getOutput().getType())))
3166  return failure();
3167 
3168  if (failed(verifyPoolingOp(*this)))
3169  return failure();
3170 
3171  return success();
3172 }
3173 
3174 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3175  MLIRContext *context, ::std::optional<Location> location,
3176  DepthwiseConv2DOp::Adaptor adaptor,
3177  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3178  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3179 
3180  int64_t inputWidth = ShapedType::kDynamic;
3181  int64_t inputHeight = ShapedType::kDynamic;
3182  int64_t inputChannels = ShapedType::kDynamic;
3183 
3184  int64_t weightWidth = ShapedType::kDynamic;
3185  int64_t weightHeight = ShapedType::kDynamic;
3186  int64_t depthChannels = ShapedType::kDynamic;
3187 
3188  // Input shape describes input width/height and batch.
3189  ShapeAdaptor inputShape(adaptor.getInput().getType());
3190  if (inputShape.hasRank()) {
3191  outputShape[0] = inputShape.getDimSize(0);
3192  inputHeight = inputShape.getDimSize(1);
3193  inputWidth = inputShape.getDimSize(2);
3194  inputChannels = inputShape.getDimSize(3);
3195  }
3196 
3197  // The weight shape describes the filter width/height and the output channels.
3198  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3199  if (weightShape.hasRank()) {
3200  weightHeight = weightShape.getDimSize(0);
3201  weightWidth = weightShape.getDimSize(1);
3202  inputChannels = ShapedType::isDynamic(inputChannels)
3203  ? weightShape.getDimSize(2)
3204  : inputChannels;
3205  depthChannels = weightShape.getDimSize(3);
3206  }
3207 
3208  // If both inputChannels and depthChannels are available we can determine
3209  // the output channels.
3210  if (ShapedType::isStatic(inputChannels) &&
3211  ShapedType::isStatic(depthChannels)) {
3212  outputShape[3] = inputChannels * depthChannels;
3213  }
3214 
3215  // Bias shape can describe the output channels.
3216  ShapeAdaptor biasShape(adaptor.getBias().getType());
3217  if (biasShape.hasRank()) {
3218  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3219  ? biasShape.getDimSize(0)
3220  : outputShape[3];
3221  }
3222 
3223  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
3224  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
3225  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3226 
3227  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3228  int64_t inputSize = inputHeight + padding[0] + padding[1];
3229  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3230  int64_t unstridedResult = inputSize - filterSize + 1;
3231  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3232  }
3233 
3234  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3235  int64_t inputSize = inputWidth + padding[2] + padding[3];
3236  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3237  int64_t unstridedResult = inputSize - filterSize + 1;
3238  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3239  }
3240 
3241  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3242  return success();
3243 }
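// For depthwise convolution the weight shape is [KH, KW, C, M], so the output
// channel count is the product of the input channels C and the channel
// multiplier M; e.g. C = 8 and M = 3 give 24 output channels. The spatial
// dimensions follow the same formula as Conv2D above.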
3244 
3245 LogicalResult DepthwiseConv2DOp::verify() {
3246  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
3247  verifyConvOpErrorIf(*this).failed())
3248  return failure();
3249  return success();
3250 }
3251 
3252 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3253  MLIRContext *context, ::std::optional<Location> location,
3254  TransposeConv2DOp::Adaptor adaptor,
3255  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3256  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
3257 
3258  int64_t inputWidth = ShapedType::kDynamic;
3259  int64_t inputHeight = ShapedType::kDynamic;
3260  int64_t weightWidth = ShapedType::kDynamic;
3261  int64_t weightHeight = ShapedType::kDynamic;
3262 
3263  // Input shape describes input width/height and batch.
3264  ShapeAdaptor inputShape(adaptor.getInput().getType());
3265  if (inputShape.hasRank()) {
3266  outputShape[0] = ShapedType::isDynamic(outputShape[0])
3267  ? inputShape.getDimSize(0)
3268  : outputShape[0];
3269  inputHeight = inputShape.getDimSize(1);
3270  inputWidth = inputShape.getDimSize(2);
3271  }
3272 
3273  // The weight shape describes the filter width/height and the output channels.
3274  ShapeAdaptor weightShape(adaptor.getWeight().getType());
3275  if (weightShape.hasRank()) {
3276  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3277  ? weightShape.getDimSize(0)
3278  : outputShape[3];
3279  weightHeight = weightShape.getDimSize(1);
3280  weightWidth = weightShape.getDimSize(2);
3281  }
3282 
3283  // Bias shape can describe the output channels.
3284  ShapeAdaptor biasShape(adaptor.getBias().getType());
3285  if (biasShape.hasRank()) {
3286  outputShape[3] = ShapedType::isDynamic(outputShape[3])
3287  ? biasShape.getDimSize(0)
3288  : outputShape[3];
3289  }
3290 
3291  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
3292  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
3293 
3294  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3295  int64_t calculateSize =
3296  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3297  outputShape[1] =
3298  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3299  }
3300 
3301  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3302  int64_t calculateSize =
3303  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3304  outputShape[2] =
3305  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
3306  }
3307 
3308  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3309  return success();
3310 }
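// Worked example: with IH = 14, stride_y = 2, out_pad = [0, 0, ...] and KH = 3,
// the inferred output height is (14 - 1) * 2 + 0 + 0 + 3 = 29; the width is
// computed the same way from IW, stride_x, the remaining pads and KW.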
3311 
3312 LogicalResult TransposeConv2DOp::verify() {
3313  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
3314  return failure();
3315 
3316  const llvm::ArrayRef<int64_t> strides = getStride();
3317  const int64_t strideY = strides[0];
3318  const int64_t strideX = strides[1];
3319 
3320  if (strideY < 1 || strideX < 1)
3321  return emitOpError("expect all stride values to be >= 1, got [")
3322  << strides << "]";
3323 
3324  const auto checkPadAgainstKernelDim =
3325  [this](int64_t pad_value, int64_t kernel_dim_size,
3326  llvm::StringRef pad_name,
3327  llvm::StringRef kernel_dim_name) -> LogicalResult {
3328  if (pad_value <= -kernel_dim_size)
3329  return emitOpError("expected ")
3330  << pad_name << " > -" << kernel_dim_name
3331  << ", but got: " << pad_name << "=" << pad_value << " and "
3332  << kernel_dim_name << "=" << kernel_dim_size;
3333  return success();
3334  };
3335 
3336  const llvm::ArrayRef<int64_t> padding = getOutPad();
3337  const int64_t outPadTop = padding[0];
3338  const int64_t outPadBottom = padding[1];
3339  const int64_t outPadLeft = padding[2];
3340  const int64_t outPadRight = padding[3];
3341 
3342  const auto weightType =
3343  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3344 
3345  if (weightType) {
3346  const int64_t kernelHeight = weightType.getDimSize(1);
3347  if (ShapedType::isStatic(kernelHeight)) {
3348  if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3349  "out_pad_top", "KH")))
3350  return failure();
3351 
3352  if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3353  "out_pad_bottom", "KH")))
3354  return failure();
3355  }
3356 
3357  const int64_t kernelWidth = weightType.getDimSize(2);
3358  if (ShapedType::isStatic(kernelWidth)) {
3359  if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3360  "out_pad_left", "KW")))
3361  return failure();
3362 
3363  if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3364  "out_pad_right", "KW")))
3365  return failure();
3366  }
3367  }
3368 
3369  // Rest of the checks depend on the output type being a RankedTensorType
3370  const auto outputType =
3371  llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3372  if (!outputType)
3373  return success();
3374 
3375  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3376  if (inputType && weightType) {
3377  const int64_t inputHeight = inputType.getDimSize(1);
3378  const int64_t kernelHeight = weightType.getDimSize(1);
3379  const int64_t outputHeight = outputType.getDimSize(1);
3380 
3381  if (ShapedType::isStatic(inputHeight) &&
3382  ShapedType::isStatic(outputHeight)) {
3383  if (outputHeight !=
3384  (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3385  return emitOpError(
3386  "dimension mismatch: expected OH == (IH - 1) * stride_y "
3387  "+ out_pad_top + out_pad_bottom + KH, but got ")
3388  << outputHeight << " != (" << inputHeight << " - 1) * "
3389  << strideY << " + " << outPadTop << " + " << outPadBottom
3390  << " + " << kernelHeight;
3391  }
3392 
3393  const int64_t inputWidth = inputType.getDimSize(2);
3394  const int64_t kernelWidth = weightType.getDimSize(2);
3395  const int64_t outputWidth = outputType.getDimSize(2);
3396 
3397  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3398  if (outputWidth !=
3399  (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3400  return emitOpError(
3401  "dimension mismatch: expected OW == (IW - 1) * stride_x "
3402  "+ out_pad_left + out_pad_right + KW, but got ")
3403  << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3404  << " + " << outPadLeft << " + " << outPadRight << " + "
3405  << kernelWidth;
3406  }
3407  }
3408 
3409  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3410 
3411  if (!biasType)
3412  return success();
3413 
3414  const int64_t biasChannels = biasType.getDimSize(0);
3415 
3416  // Skip further checks if bias is dynamic
3417  if (biasChannels == ShapedType::kDynamic)
3418  return success();
3419 
3420  const int64_t outputChannels = outputType.getDimSize(3);
3421  if (!ShapedType::isDynamic(outputChannels) &&
3422  biasChannels != outputChannels && biasChannels != 1)
3423  return emitOpError(
3424  "bias channels expected to be equal to output channels (")
3425  << outputChannels << ") or 1, got " << biasChannels;
3426 
3427  return success();
3428 }
3429 
3430 LogicalResult RescaleOp::verify() {
3431  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3432  if (!inputType) {
3433  emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3434  return failure();
3435  }
3436 
3437  auto inputElementType =
3438  getStorageElementTypeOrSelf(inputType.getElementType());
3439  if (!mlir::isa<IntegerType>(inputElementType)) {
3440  emitOpError("expect input to have integer element type, got ")
3441  << inputElementType;
3442  return failure();
3443  }
3444 
3445  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3446  if (!outputType) {
3447  emitOpError("expect shaped tensor for output, got ")
3448  << getOutput().getType();
3449  return failure();
3450  }
3451 
3452  auto outputElementType =
3453  getStorageElementTypeOrSelf(outputType.getElementType());
3454  if (!mlir::isa<IntegerType>(outputElementType)) {
3455  emitOpError("expect output to have integer element type, got ")
3456  << outputElementType;
3457  return failure();
3458  }
3459 
3460  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
3461  .failed())
3462  return failure();
3463 
3464  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
3465  .failed())
3466  return failure();
3467 
3468  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3469  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3470  return failure();
3471 
3472  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3473  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3474  return failure();
3475 
3476  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3477  if (!multiplierType) {
3478  emitOpError("expect shaped tensor for multiplier, got ")
3479  << getMultiplier().getType();
3480  return failure();
3481  }
3482 
3483  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3484  if (!shiftType) {
3485  emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3486  return failure();
3487  }
3488 
3489  // multiplier element type must be i32 for scale32 = true
3490  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3491  emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3492  << multiplierType.getElementType();
3493  return failure();
3494  }
3495 
3496  // multiplier element type must be i16 for scale32 = false
3497  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3498  emitOpError(
3499  "expect i16 element type for multiplier for scale32=false, got ")
3500  << multiplierType.getElementType();
3501  return failure();
3502  }
3503 
3504  if (!inputType.hasRank())
3505  return success();
3506 
3507  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
3508  // if per_channel = false, and otherwise is the size of the last axis of the
3509  // input shape.
3510  int64_t numChannels = 1;
3511  if (getPerChannel()) {
3512  if (inputType.getRank() < 1) {
3513  emitOpError("requires input to be at least rank 1 when per_channel is "
3514  "true, but got rank ")
3515  << inputType.getRank();
3516  return failure();
3517  }
3518  numChannels = inputType.getDimSize(inputType.getRank() - 1);
3519  }
3520 
3521  if (!multiplierType.hasRank())
3522  return success();
3523 
3524  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
3525  // multiplier input has rank 1 by dialect definition
3526  if (multiplierShape[0] != ShapedType::kDynamic &&
3527  multiplierShape[0] != numChannels) {
3528  emitOpError("expect shape of { ")
3529  << numChannels << " } for multiplier input, got { "
3530  << multiplierShape[0] << " }";
3531  return failure();
3532  }
3533 
3534  if (!shiftType.hasRank())
3535  return success();
3536 
3537  ArrayRef<int64_t> shiftShape = shiftType.getShape();
3538  // shift input has rank 1 by dialect definition
3539  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3540  emitOpError("expect shape of { ")
3541  << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3542  return failure();
3543  }
3544 
3545  return success();
3546 }
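// For example, with per_channel = true and an input of type
// tensor<1x8x8x32xi8>, both the multiplier and the shift operands are expected
// to have shape {32}; with per_channel = false they must have shape {1}.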
3547 
3548 LogicalResult RescaleOp::inferReturnTypeComponents(
3549  MLIRContext *context, ::std::optional<Location> location,
3550  RescaleOp::Adaptor adaptor,
3551  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3552  ShapeAdaptor inputShape(adaptor.getInput().getType());
3553  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3554  return success();
3555 }
3556 
3557 LogicalResult IfOp::inferReturnTypeComponents(
3558  MLIRContext *context, ::std::optional<Location> location,
3559  IfOp::Adaptor adaptor,
3560  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3561  llvm::SmallVector<tosa::YieldOp> yieldOps;
3562  for (Region *region : adaptor.getRegions()) {
3563  for (auto &block : *region)
3564  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3565  yieldOps.push_back(returnOp);
3566  }
3567 
3568  if (yieldOps.empty())
3569  return failure();
3570 
3571  // Get the initial type information for the yield op.
3572  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3573  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3574  for (auto operand : yieldOps.front().getOperands()) {
3575  resultKnowledge.push_back(
3576  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3577  }
3578 
3579  for (auto yieldOp : yieldOps) {
3580  if (resultKnowledge.size() != yieldOp.getNumOperands())
3581  return failure();
3582 
3583  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3584  int32_t index = it.index();
3585  auto meet = ValueKnowledge::meet(
3586  resultKnowledge[index],
3587  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
3588  if (!meet)
3589  continue;
3590  resultKnowledge[index] = meet;
3591  }
3592  }
3593 
3594  for (const ValueKnowledge &result : resultKnowledge) {
3595  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3596  }
3597 
3598  return success();
3599 }
3600 
3601 LogicalResult WhileOp::inferReturnTypeComponents(
3602  MLIRContext *context, ::std::optional<Location> location,
3603  WhileOp::Adaptor adaptor,
3604  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3605  llvm::SmallVector<tosa::YieldOp> yieldOps;
3606  for (auto &block : adaptor.getBodyGraph())
3607  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3608  yieldOps.push_back(returnOp);
3609 
3610  // TOSA's while must have a tosa.yield as its terminator. If not found this
3611  // tosa.while is invalid.
3612  if (yieldOps.empty())
3613  return failure();
3614 
3615  // Get the initial type information from the operand types.
3616  llvm::SmallVector<ValueKnowledge> resultKnowledge;
3617  resultKnowledge.reserve(yieldOps.front().getNumOperands());
3618  for (auto operand : yieldOps.front().getOperands()) {
3619  resultKnowledge.push_back(
3620  ValueKnowledge::getKnowledgeFromType(operand.getType()));
3621  }
3622 
3623  for (auto yieldOp : yieldOps) {
3624  if (resultKnowledge.size() != yieldOp.getNumOperands())
3625  return failure();
3626 
3627  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
3628  int32_t index = it.index();
3629  if (auto meet = ValueKnowledge::meet(
3630  resultKnowledge[index],
3631  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
3632  resultKnowledge[index] = meet;
3633  }
3634  }
3635  }
3636 
3637  for (const ValueKnowledge &result : resultKnowledge) {
3638  inferredReturnShapes.push_back(result.getShapedTypeComponents());
3639  }
3640 
3641  return success();
3642 }
3643 
3644 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3645  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3646  return llvm::to_vector<4>(vt.getShape());
3647  return std::nullopt;
3648 }
3649 
3650 // The parse and print methods of IfOp follow the implementation in the SCF dialect.
3651 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
3652  // Create the regions for 'then'.
3653  result.regions.reserve(2);
3654  Region *thenRegion = result.addRegion();
3655  Region *elseRegion = result.addRegion();
3656 
3657  auto &builder = parser.getBuilder();
3658  OpAsmParser::UnresolvedOperand cond;
3659  // Create an i1 tensor type for the boolean condition.
3660  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
3661  if (parser.parseOperand(cond) ||
3662  parser.resolveOperand(cond, i1Type, result.operands))
3663  return failure();
3664  // Parse optional results type list.
3665  if (parser.parseOptionalArrowTypeList(result.types))
3666  return failure();
3667  // Parse the 'then' region.
3668  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
3669  return failure();
3670 
3671  // If we find an 'else' keyword then parse the 'else' region.
3672  if (!parser.parseOptionalKeyword("else")) {
3673  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
3674  return failure();
3675  }
3676 
3677  // Parse the optional attribute list.
3678  if (parser.parseOptionalAttrDict(result.attributes))
3679  return failure();
3680  return success();
3681 }
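// The custom assembly parsed above therefore has the form (structure only,
// shown here for illustration):
//   <cond-operand> [-> (result-types)] then-region [`else` else-region]
//   [attr-dict]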
3682 
3683 void IfOp::print(OpAsmPrinter &p) {
3684  bool printBlockTerminators = false;
3685 
3686  p << " " << getCondition();
3687  if (!getResults().empty()) {
3688  p << " -> (" << getResultTypes() << ")";
3689  // Print yield explicitly if the op defines values.
3690  printBlockTerminators = true;
3691  }
3692  p << ' ';
3693  p.printRegion(getThenGraph(),
3694  /*printEntryBlockArgs=*/false,
3695  /*printBlockTerminators=*/printBlockTerminators);
3696 
3697  // Print the 'else' region if it exists and has a block.
3698  auto &elseRegion = getElseGraph();
3699  if (!elseRegion.empty()) {
3700  p << " else ";
3701  p.printRegion(elseRegion,
3702  /*printEntryBlockArgs=*/false,
3703  /*printBlockTerminators=*/printBlockTerminators);
3704  }
3705 
3706  p.printOptionalAttrDict((*this)->getAttrs());
3707 }
3708 
3709 LogicalResult IfOp::verify() {
3710  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3711  "'then_graph' arguments", getInputList(),
3712  "'input_list'")
3713  .failed())
3714  return failure();
3715 
3716  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3717  "'else_graph' arguments", getInputList(),
3718  "'input_list'")
3719  .failed())
3720  return failure();
3721 
3722  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3723  if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3724  "'then_graph' results", getOutputList(),
3725  "'output_list'")
3726  .failed())
3727  return failure();
3728 
3729  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3730  if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3731  "'else_graph' results", getOutputList(),
3732  "'output_list'")
3733  .failed())
3734  return failure();
3735 
3736  auto condType = getCondition().getType();
3737  if (errorIfShapeNotSizeOne(*this, condType).failed())
3738  return emitOpError() << "'condition' must be a size 1 tensor, got "
3739  << condType;
3740 
3741  return success();
3742 }
3743 
3744 LogicalResult WhileOp::verify() {
3745  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3746  getOutputList(), "'output_list'")
3747  .failed())
3748  return failure();
3749 
3750  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3751  "'cond_graph' arguments", getInputList(),
3752  "'input_list'")
3753  .failed())
3754  return failure();
3755 
3756  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3757  "'body_graph' arguments", getInputList(),
3758  "'input_list'")
3759  .failed())
3760  return failure();
3761 
3762  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3763  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3764  "'body_graph' results", getInputList(),
3765  "'input_list'")
3766  .failed())
3767  return failure();
3768 
3769  // Condition block output must be a single element tensor with a single bool
3770  // value.
3771  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3772  if (condYield.getInputs().size() != 1)
3773  return emitOpError() << "require 'cond_graph' only have one result";
3774 
3775  auto condOutType = condYield.getInputs()[0].getType();
3776  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3777  return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3778  << condOutType;
3779 
3780  if (!getElementTypeOrSelf(condOutType).isInteger(1))
3781  return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3782  << condOutType;
3783 
3784  return success();
3785 }
3786 
3787 LogicalResult ReverseOp::verify() {
3788  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
3789  /* outType = */ getOutput().getType())
3790  .failed())
3791  return failure();
3792  TensorType inputType = getInput1().getType();
3793  TensorType outputType = getOutput().getType();
3794  int32_t reverseAxis = getAxis();
3795 
3796  if (reverseAxis < 0)
3797  return emitOpError("expected non-negative reverse axis");
3798  if (inputType.hasRank()) {
3799  int64_t inputRank = inputType.getRank();
3800  // We allow for a special case where the input/output shape has rank 0 and
3801  // axis is also 0.
3802  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
3803  return emitOpError("expect input tensor rank (")
3804  << inputRank << ") to be larger than reverse axis (" << reverseAxis
3805  << ")";
3806  }
3807  if (outputType.hasRank()) {
3808  int64_t outputRank = outputType.getRank();
3809  if (inputType.hasRank() && outputRank != inputType.getRank())
3810  return emitOpError(
3811  "expect output tensor rank to be equal to input tensor rank");
3812  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
3813  return emitOpError("expect output tensor rank (")
3814  << outputRank << ") to be larger than reverse axis ("
3815  << reverseAxis << ")";
3816  }
3817  return success();
3818 }
3819 
3820 LogicalResult tosa::SelectOp::verify() {
3821  // verify input2 and input3 have same element type as output
3822  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
3823  /* outType = */ getOutput().getType())
3824  .failed() ||
3825  verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
3826  /* outType = */ getOutput().getType())
3827  .failed()) {
3828  return failure();
3829  }
3830  // verify input1 has element type of bool
3831  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
3832  if (!predicateType) {
3833  return emitOpError("expect shaped tensor for input1, got ")
3834  << getInput1().getType();
3835  }
3836  auto predicateElementType = predicateType.getElementType();
3837  if (!predicateElementType.isInteger(1)) {
3838  return emitOpError("expect element type of bool for input1, got ")
3839  << predicateElementType;
3840  }
3841 
3842  return success();
3843 }
3844 
3845 LogicalResult tosa::VariableOp::verify() {
3846  StringRef symName = getName();
3847  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
3848  if (succeeded(varOp))
3849  return emitOpError("illegal to have multiple declaration of '")
3850  << symName << "'";
3851 
3852  return success();
3853 }
3854 
3855 LogicalResult tosa::VariableReadOp::verify() {
3856  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
3857  .failed())
3858  return failure();
3859 
3860  return success();
3861 }
3862 
3863 LogicalResult tosa::VariableWriteOp::verify() {
3864  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
3865  .failed())
3866  return failure();
3867 
3868  return success();
3869 }
3870 
3871 // The parse and print methods of WhileOp follow the implementation in the SCF dialect.
3872 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
3873  SmallVector<OpAsmParser::Argument, 4> regionArgs;
3874  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
3875  Region *cond = result.addRegion();
3876  Region *body = result.addRegion();
3877 
3878  OptionalParseResult listResult =
3879  parser.parseOptionalAssignmentList(regionArgs, operands);
3880  if (listResult.has_value() && failed(listResult.value()))
3881  return failure();
3882 
3883  FunctionType functionType;
3884  SMLoc typeLoc = parser.getCurrentLocation();
3885  if (failed(parser.parseColonType(functionType)))
3886  return failure();
3887 
3888  result.addTypes(functionType.getResults());
3889 
3890  if (functionType.getNumInputs() != operands.size()) {
3891  return parser.emitError(typeLoc)
3892  << "expected as many input types as operands "
3893  << "(expected " << operands.size() << " got "
3894  << functionType.getNumInputs() << ")";
3895  }
3896 
3897  // Resolve input operands.
3898  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
3899  parser.getCurrentLocation(),
3900  result.operands)))
3901  return failure();
3902 
3903  // Propagate the types into the region arguments.
3904  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
3905  regionArgs[i].type = functionType.getInput(i);
3906 
3907  return failure(parser.parseRegion(*cond, regionArgs) ||
3908  parser.parseKeyword("do") || parser.parseRegion(*body) ||
3909  parser.parseOptionalAttrDictWithKeyword(result.attributes));
3910 }
3911 
3912 static void printInitializationList(OpAsmPrinter &parser,
3913  Block::BlockArgListType blocksArgs,
3914  ValueRange initializers,
3915  StringRef prefix = "") {
3916  assert(blocksArgs.size() == initializers.size() &&
3917  "expected same length of arguments and initializers");
3918  if (initializers.empty())
3919  return;
3920 
3921  parser << prefix << '(';
3922  llvm::interleaveComma(
3923  llvm::zip(blocksArgs, initializers), parser,
3924  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3925  parser << ")";
3926 }
3927 
3928 void WhileOp::print(OpAsmPrinter &parser) {
3929  printInitializationList(parser, getCondGraph().front().getArguments(),
3930  getInputList(), " ");
3931  parser << " : ";
3932  parser.printFunctionalType(getInputList().getTypes(),
3933  getResults().getTypes());
3934  parser << ' ';
3935  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
3936  parser << " do ";
3937  parser.printRegion(getBodyGraph());
3938  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
3939 }
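// Together with the parser above, this gives the op a custom assembly of the
// form (structure only, for illustration):
//   (%arg = %init, ...) : functional-type cond-region `do` body-region
//   [`attributes` attr-dict]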
3940 
3941 // Create a rank-1 const tensor for the zero point of the source tensor.
3942 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
3943  Location loc,
3944  Type srcElemType,
3945  int64_t zp) {
3946  srcElemType = getStorageElementTypeOrSelf(srcElemType);
3947  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
3948  if (llvm::isa<FloatType>(srcElemType)) {
3949  auto zpAttr = DenseElementsAttr::get(
3950  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
3951  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
3952  }
3953  if (llvm::isa<IntegerType>(srcElemType)) {
3954  auto zpAttr =
3955  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
3956  return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
3957  }
3958  llvm::errs() << "zero point is not allowed for unsupported data types\n";
3959  return std::nullopt;
3960 }
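// Illustrative use (hypothetical call site): createZeroPointTensor(builder,
// loc, builder.getIntegerType(8), 0) materializes a tosa.const of type
// tensor<1xi8> holding the value 0; for float element types the zero point is
// materialized as a float constant instead.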
3961 
3962 //===----------------------------------------------------------------------===//
3963 // TOSA Shape and Shape Operators Helper functions.
3964 //===----------------------------------------------------------------------===//
3965 
3966 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
3967  return mlir::isa<tosa::shapeType>(t);
3968 }
3969 
3970 LogicalResult
3971 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
3972  int rank) {
3973  if (rank < 0)
3974  return emitError() << "invalid rank (must be >= 0): " << rank;
3975  return success();
3976 }
3977 
3978 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
3979  for (auto v : op->getOperands()) {
3980  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
3981  Operation *definingOp = v.getDefiningOp();
3982  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
3983  return op->emitOpError("shape operand is not compile time resolvable");
3984  }
3985  }
3986  }
3987  return success();
3988 }
3989 
3990 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
3991  for (auto type : op->getOperandTypes()) {
3992  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3993  return op->emitOpError("must have operands with tosa shape type");
3994  }
3995  }
3996  for (auto type : op->getResultTypes()) {
3997  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
3998  return op->emitOpError("must have result with tosa shape type");
3999  }
4000  }
4001  return success();
4002 }
4003 
4004 LogicalResult
4005 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
4006  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
4007  failed(verifyTosaShapeOperator(op)))
4008  return failure();
4009 
4010  // Helper returning the rank of a tosa shape type.
4011  auto getRank = [](const Type type) {
4012  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4013  };
4014  auto operandTypes = op->getOperandTypes();
4015  auto resultTypes = op->getResultTypes();
4016 
4017  auto rank = getRank(*op->getOperandTypes().begin());
4018  for (auto type : operandTypes) {
4019  if (getRank(type) != rank) {
4020  return op->emitOpError("operands don't have matching ranks");
4021  }
4022  }
4023  for (auto type : resultTypes) {
4024  if (getRank(type) != rank) {
4025  return op->emitOpError("result shape has different rank than operands");
4026  }
4027  }
4028  return success();
4029 }
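// For example, an op whose operands are all !tosa.shape<3> and whose result is
// !tosa.shape<3> passes this check, while mixing !tosa.shape<3> with
// !tosa.shape<2> operands (or results) is rejected with a rank-mismatch error.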
4030 
4031 //===----------------------------------------------------------------------===//
4032 // TOSA Shape Operators verify functions.
4033 //===----------------------------------------------------------------------===//
4034 
4035 LogicalResult tosa::ConstShapeOp::verify() {
4036  // Check that the values attribute is one-dimensional (rank 1).
4037  auto valuesRank = getValues().getType().getRank();
4038  if (valuesRank != 1)
4039  return emitOpError("expect elements in attribute values with rank 1");
4040  // Check that the number of elements in the values attribute equals the rank of the result shape.
4041  auto count = getValues().getNumElements();
4042  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4043  if (!(count == rank || (count == 1 && rank == 0))) {
4044  return emitOpError("expect number of elements in attribute values (")
4045  << count << ") to be equal to the rank (" << rank
4046  << ") for the result shape type";
4047  }
4048  return success();
4049 }
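// For example, a const_shape op whose values attribute holds four elements must
// produce a result of type !tosa.shape<4>; the one exception is a single value
// paired with a rank-0 result shape.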
4050 
4051 //===----------------------------------------------------------------------===//
4052 // TOSA Attribute Definitions.
4053 //===----------------------------------------------------------------------===//
4054 
4055 #define GET_ATTRDEF_CLASSES
4056 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4057 
4058 //===----------------------------------------------------------------------===//
4059 // TOSA Type Definitions.
4060 //===----------------------------------------------------------------------===//
4061 #define GET_TYPEDEF_CLASSES
4062 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4063 
4064 //===----------------------------------------------------------------------===//
4065 // TOSA Operator Definitions.
4066 //===----------------------------------------------------------------------===//
4067 
4068 #define GET_OP_CLASSES
4069 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:1034
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition: TosaOps.cpp:2182
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:725
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2755
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition: TosaOps.cpp:294
static FailureOr< tosa::VariableOp > findVariableDecl(Operation *op, StringRef symName)
Definition: TosaOps.cpp:672
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition: TosaOps.cpp:662
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2780
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:334
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition: TosaOps.cpp:706
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2970
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:1148
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition: TosaOps.cpp:1162
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:2805
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:2873
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition: TosaOps.cpp:2252
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:1012
static LogicalResult verifyConvOpErrorIf(T op)
Definition: TosaOps.cpp:480
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:439
std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition: TosaOps.cpp:277
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:2861
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:2771
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:1193
Type getStorageElementTypeOrSelf(Type type)
Definition: TosaOps.cpp:283
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition: TosaOps.cpp:1108
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:988
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:1063
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition: TosaOps.cpp:620
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:136
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition: TosaOps.cpp:2210
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:3912
static LogicalResult verifyPoolingOp(T op)
Definition: TosaOps.cpp:788
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition: TosaOps.cpp:1282
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:249
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:188
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:94
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:130
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
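A minimal sketch of the typical ShapeAdaptor queries listed above; the helper name is hypothetical and the adaptor is assumed to wrap some operand shape supplied by the caller.
// Hypothetical helper walking the dimensions of a shape adaptor.
void inspectShape(ShapeAdaptor shape) {
  if (!shape.hasRank())
    return;                                 // Unranked: nothing more to query.
  for (int i = 0, e = shape.getRank(); i < e; ++i) {
    if (shape.isDynamicDim(i))
      continue;                             // Extent unknown at compile time.
    int64_t size = shape.getDimSize(i);     // Static extent of dimension i.
    (void)size;
  }
}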
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
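For orientation, a hedged sketch of how these element-type queries are commonly combined; the predicate name and the 32-bit cutoff are assumptions, not part of this file.
// Hypothetical check: accept only 16- or 32-bit float element types.
bool isSmallFloat(Type elemType) {
  return (elemType.isF16() || elemType.isBF16() || elemType.isF32()) &&
         elemType.getIntOrFloatBitWidth() <= 32;
}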
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
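A hedged sketch of the Operation::walk idiom these helpers support; 'op' is an assumed Operation* and the WhileOp check is illustrative only.
// Hypothetical walk: stop as soon as a tosa.while_loop is found.
WalkResult result = op->walk([](Operation *nested) {
  if (isa<tosa::WhileOp>(nested))
    return WalkResult::interrupt();   // Abort the traversal early.
  return WalkResult::advance();       // Keep visiting nested ops.
});
bool containsWhile = result.wasInterrupted();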
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:919
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:3990
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:4005
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:3978
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition: Traits.cpp:59
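An illustrative use of getBroadcastedShape with made-up static shapes; size-1 dimensions broadcast against the other operand, numpy-style.
SmallVector<int64_t> result;
SmallVector<int64_t> lhs = {1, 3, 5};
SmallVector<int64_t> rhs = {2, 1, 5};
if (OpTrait::util::getBroadcastedShape(lhs, rhs, result)) {
  // result == {2, 3, 5}; incompatible shapes would have returned false.
}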
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:198
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct the ConvOp output type with the correct bit width based on the input/weight widths.
Definition: QuantUtils.cpp:289
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition: TosaOps.cpp:225
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:269
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:162
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition: TosaOps.cpp:250
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:214
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:3942
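A hedged usage sketch of createZeroPointTensor; 'builder' and 'loc' are assumed to come from the surrounding lowering, and the i8 element type is illustrative.
if (std::optional<Value> zp =
        tosa::createZeroPointTensor(builder, loc, builder.getI8Type(), /*zp=*/0)) {
  Value zeroPoint = *zp;   // Constant zero-point tensor, usable as an op operand.
  (void)zeroPoint;
}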
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:3966
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:243
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition: TosaOps.cpp:314
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
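An illustrative compatibility check; the shapes are assumptions, and a dynamic dimension is treated as compatible with any static size.
SmallVector<int64_t> a = {2, ShapedType::kDynamic};
SmallVector<int64_t> b = {2, 8};
if (succeeded(verifyCompatibleShape(a, b))) {
  // Both shapes can describe the same tensor once dynamic dims are resolved.
}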
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
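A hedged sketch of the matchPattern + m_Constant idiom listed above; 'value' is an assumed Value defined by some op.
DenseElementsAttr attr;
if (matchPattern(value, m_Constant(&attr))) {
  // 'value' is produced by a constant-foldable op; 'attr' now holds its contents.
}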
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45
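Finally, a hedged sketch of how these ValueKnowledge helpers are typically combined during shape propagation; 'lhsType' and 'rhsType' are assumed tensor types supplied by the caller.
ValueKnowledge lhs = ValueKnowledge::getKnowledgeFromType(lhsType);
ValueKnowledge rhs = ValueKnowledge::getKnowledgeFromType(rhsType);
// meet() merges the static shape facts of both inputs into one refined result.
ValueKnowledge combined = ValueKnowledge::meet(lhs, rhs);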