TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 
47 namespace {
48 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
49 
50 //===----------------------------------------------------------------------===//
51 // Dialect Function Inliner Interface.
52 //===----------------------------------------------------------------------===//
53 struct TosaInlinerInterface : public DialectInlinerInterface {
54  using DialectInlinerInterface::DialectInlinerInterface;
55 
56  //===--------------------------------------------------------------------===//
57  // Analysis Hooks.
58  //===--------------------------------------------------------------------===//
59 
60  /// All operations can be inlined by default.
61  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
62  IRMapping &map) const final {
63  return true;
64  }
65 
66  /// All regions with If and While parent operators can be inlined.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &map) const final {
69  return (isa<tosa::IfOp>(dest->getParentOp()) ||
70  isa<tosa::WhileOp>(dest->getParentOp()));
71  }
72 };
73 
74 /// This class implements the bytecode interface for the Tosa dialect.
75 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
76  TosaDialectBytecodeInterface(Dialect *dialect)
77  : BytecodeDialectInterface(dialect) {}
78 
79  //===--------------------------------------------------------------------===//
80  // Attributes
81 
82  Attribute readAttribute(DialectBytecodeReader &reader) const override {
83  return ::readAttribute(getContext(), reader);
84  }
85 
86  LogicalResult writeAttribute(Attribute attr,
87  DialectBytecodeWriter &writer) const override {
88  return ::writeAttribute(attr, writer);
89  }
90 
91  //===--------------------------------------------------------------------===//
92  // Types
93 
94  Type readType(DialectBytecodeReader &reader) const override {
95  return ::readType(getContext(), reader);
96  }
97 
98  LogicalResult writeType(Type type,
99  DialectBytecodeWriter &writer) const override {
100  return ::writeType(type, writer);
101  }
102 
103  void writeVersion(DialectBytecodeWriter &writer) const final {
104  // TODO: Populate.
105  }
106 
107  std::unique_ptr<DialectVersion>
108  readVersion(DialectBytecodeReader &reader) const final {
109  // TODO: Populate
110  reader.emitError("Dialect does not support versioning");
111  return nullptr;
112  }
113 
114  LogicalResult upgradeFromVersion(Operation *topLevelOp,
115  const DialectVersion &version) const final {
116  return success();
117  }
118 };
119 
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // TOSA control flow support.
124 //===----------------------------------------------------------------------===//
125 
126 /// Returns the while loop body.
127 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
128 
129 //===----------------------------------------------------------------------===//
130 // Tosa dialect initialization.
131 //===----------------------------------------------------------------------===//
132 
133 void TosaDialect::initialize() {
134  addTypes<
135 #define GET_TYPEDEF_LIST
136 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
137  >();
138  addOperations<
139 #define GET_OP_LIST
140 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
141  >();
142  addAttributes<
143 #define GET_ATTRDEF_LIST
144 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
145  >();
146  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
147  declarePromisedInterfaces<
148  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
149  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
150  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
151  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
152  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
153  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
154  GreaterEqualOp, MatMulOp>();
155 }
156 
157 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
158  Type type, Location loc) {
159  // Tosa dialect constants only support ElementsAttr, unlike the standard
160  // dialect constant, which supports all attributes.
161  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
162  return builder.create<tosa::ConstShapeOp>(
163  loc, type, llvm::cast<DenseIntElementsAttr>(value));
164  }
165  if (llvm::isa<ElementsAttr>(value))
166  return builder.create<tosa::ConstOp>(loc, type,
167  llvm::cast<ElementsAttr>(value));
168  return nullptr;
169 }
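// Illustrative example for materializeConstant (hypothetical IR, not from this
// file): a fold that produces a DenseIntElementsAttr for a !tosa.shape value
// materializes roughly as
//   %0 = tosa.const_shape {value = dense<[1, 2]> : tensor<2xindex>} : !tosa.shape<2>
// while any other ElementsAttr materializes as a tosa.const.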
170 
171 //===----------------------------------------------------------------------===//
172 // Parsers and printers
173 //===----------------------------------------------------------------------===//
174 
175 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
176  Attribute &attr) {
177  if (succeeded(parser.parseOptionalEqual())) {
178  if (failed(parser.parseAttribute(attr))) {
179  return parser.emitError(parser.getCurrentLocation())
180  << "expected attribute";
181  }
182  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
183  typeAttr = TypeAttr::get(typedAttr.getType());
184  }
185  return success();
186  }
187 
188  Type type;
189  if (failed(parser.parseColonType(type))) {
190  return parser.emitError(parser.getCurrentLocation()) << "expected type";
191  }
192  typeAttr = TypeAttr::get(type);
193 
194  return success();
195 }
196 
197 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
198  Attribute attr) {
199  bool needsSpace = false;
200  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
201  if (!typedAttr || typedAttr.getType() != type.getValue()) {
202  p << ": ";
203  p.printAttribute(type);
204  needsSpace = true; // subsequent attr value needs a space separator
205  }
206  if (attr) {
207  if (needsSpace)
208  p << ' ';
209  p << "= ";
210  p.printAttribute(attr);
211  }
212 }
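// Illustrative usage of parseTypeOrAttr/printTypeOrAttr (assumed, not from this
// file): these hooks presumably back a custom<TypeOrAttr> assembly directive,
// so a constant-like operand prints in one of two forms:
//   = dense<0.0> : tensor<1xf32>   // explicit attribute; the type is taken from it
//   : tensor<1xf32>                // type-only form, no attribute present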
213 
214 //===----------------------------------------------------------------------===//
215 // TOSA Operator Verifiers.
216 //===----------------------------------------------------------------------===//
217 
218 template <typename T>
219 static LogicalResult verifyConvOp(T op) {
220  // All TOSA conv ops have input and weight arguments, which must be ranked
221  // tensors.
222  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
223  if (!inputType) {
224  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
225  return failure();
226  }
227 
228  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
229  if (!weightType) {
230  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
231  return failure();
232  }
233 
234  auto inputEType = inputType.getElementType();
235  auto weightEType = weightType.getElementType();
236  auto biasEType =
237  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
238  auto resultEType =
239  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
240  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242 
243  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
244  inputEType = quantType.getStorageType();
245 
246  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
247  biasEType = quantType.getStorageType();
248 
249  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
250  resultEType = quantType.getStorageType();
251 
252  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
253  // for now, only enforce bias element type == result element type for
254  // float types.
255  op.emitOpError(
256  "expect both bias and result to have same element type, got ")
257  << biasEType << " and " << resultEType;
258  return failure();
259  }
260 
261  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
262  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
263  if (inputEType != weightEType) {
264  op.emitOpError(
265  "expect both input and weight to have same element type, got ")
266  << inputEType << " and " << weightEType;
267  return failure();
268  }
269  }
270 
271  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
272  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
273 
274  // Either both must be float or both non-float.
275  if (inputIsFloat != weightIsFloat) {
276  op.emitOpError(
277  "expect both input and weight to be float or not together, got ")
278  << inputEType << " and " << weightEType;
279  return failure();
280  }
281 
282  // We require an explicit input zero point and weight zero point for i8
283  // convolution.
284  if (!op.getInputZp() && !op.getWeightZp())
285  return inputEType.isInteger(8) ? failure() : success();
286 
287  ElementsAttr inputZpAttr;
288  ElementsAttr weightZpAttr;
289  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
290  !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
291  op.emitOpError(
292  "bail out if the actual value of zero points cannot be determined");
293  return failure();
294  }
295 
296  // Get and verify explicit zero points.
297  int64_t inputZpVal;
298  int64_t weightZpVal;
299 
300  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
301  tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
302  .failed()) {
303  op.emitOpError("input zero point must be zero for non-int8 integer types");
304  return failure();
305  }
306 
307  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
308  tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
309  .failed()) {
310  op.emitOpError("weight zero point must be zero for non-int8 integer types");
311  return failure();
312  }
313 
314  return success();
315 }
316 
317 LogicalResult tosa::ConstOp::verify() {
318 
319  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
320  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
321 
322  if (!attrType || !outputType) {
323  emitOpError("expected tensors for attr/result type");
324  return failure();
325  }
326 
327  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
328  outputType.getElementType())) {
329  if (result.getStorageType() == attrType.getElementType())
330  return success();
331  }
332 
333  if (attrType.getElementType() != outputType.getElementType()) {
334  emitOpError("expected same attr/result element types");
335  return failure();
336  }
337 
338  return success();
339 }
340 
341 template <typename T>
342 static LogicalResult verifyConvOpModes(T op) {
343  auto inputEType =
344  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
345 
346  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
347  inputEType = quantType.getStorageType();
348 
349  auto accType = op.getAccType();
350  if (inputEType.isInteger(8) && !accType.isInteger(32))
351  return op.emitOpError("accumulator type for i8 tensor is not i32");
352 
353  if (inputEType.isInteger(16) && !accType.isInteger(48))
354  return op.emitOpError("accumulator type for i16 tensor is not i48");
355 
356  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
357  return op.emitOpError("accumulator type for f8 tensor is not f16");
358 
359  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
360  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
361 
362  if (inputEType.isBF16() && !accType.isF32())
363  return op.emitOpError("accumulator type for bf16 tensor is not f32");
364 
365  if (inputEType.isF32() && !accType.isF32())
366  return op.emitOpError("accumulator type for f32 tensor is not f32");
367 
368  auto resultEType =
369  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
370 
371  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372  resultEType = quantType.getStorageType();
373 
374  // Check allowed input/result element type combinations.
375  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
376  (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
377  (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
378  (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
379  (inputEType.isF16() && resultEType.isF16()) ||
380  (inputEType.isBF16() && resultEType.isBF16()) ||
381  (inputEType.isF32() && resultEType.isF32()))
382  return success();
383 
384  return op.emitOpError("input/output element types are incompatible.");
385 }
386 
387 // Verify that inType and outType have the same element types.
388 template <typename T>
389 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
390  auto inputType = llvm::dyn_cast<TensorType>(inType);
391  auto outputType = llvm::dyn_cast<TensorType>(outType);
392  if (!inputType) {
393  op.emitOpError("expect shaped tensor for input, got ") << inType;
394  return failure();
395  }
396  if (!outputType) {
397  op.emitOpError("expect shaped tensor for output, got ") << outType;
398  return failure();
399  }
400  auto inputElementType = inputType.getElementType();
401  auto outputElementType = outputType.getElementType();
402  auto inputQuantType =
403  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
404  auto outputQuantType =
405  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
406  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
407  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
408  inputElementType != outputElementType) {
409  // only check if both element types are int/index/float/UniformQuantized
410  // eg, not sure how to check quant::QuantizedType
411  // this happens in test_conv2d_q_grouped_convolution in
412  // tfl-to-tosa-pipeline.mlir
413  op.emitOpError("expect input and output to have same element type, got ")
414  << inputElementType << " and " << outputElementType;
415  return failure();
416  }
417  return success();
418 }
419 
420 LogicalResult tosa::ArgMaxOp::verify() {
421  // Ensure the output is an integer type.
422  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
423  if (!resultETy.isIntOrIndex())
424  return emitOpError("result tensor is not of integer type");
425 
426  // Ensure axis is within the tensor rank
427  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
428  const int64_t axis = getAxisAttr().getInt();
429  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
430  return emitOpError("specified axis is outside the rank of the tensor");
431 
432  return success();
433 }
434 
435 LogicalResult tosa::AvgPool2dOp::verify() {
436  auto inputType = llvm::cast<ShapedType>(getInput().getType());
437 
438  auto inputETy = inputType.getElementType();
439  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
440 
441  if (auto quantType =
442  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
443  inputETy = quantType.getStorageType();
444 
445  if (auto quantType =
446  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
447  resultETy = quantType.getStorageType();
448 
449  auto accType = getAccType();
450  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
451  return emitOpError("accumulator type for integer tensor is not i32");
452 
453  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
454  return emitOpError("accumulator type for f16 tensor is not f16/f32");
455 
456  if (inputETy.isBF16() && !accType.isF32())
457  return emitOpError("accumulator type for bf16 tensor is not f32");
458 
459  if (inputETy.isF32() && !accType.isF32())
460  return emitOpError("accumulator type for f32 tensor is not f32");
461 
462  if ((inputETy.isF32() && resultETy.isF32()) ||
463  (inputETy.isF16() && resultETy.isF16()) ||
464  (inputETy.isBF16() && resultETy.isBF16()) ||
465  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
466  (inputETy.isInteger(16) && resultETy.isInteger(16)))
467  return success();
468 
469  return emitOpError("input/output element types are incompatible.");
470 }
471 
472 LogicalResult tosa::ClampOp::verify() {
473  mlir::Type inputETy =
474  llvm::cast<ShapedType>(getInput().getType()).getElementType();
475  if (auto quantType =
476  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
477  inputETy = quantType.getStorageType();
478  }
479  mlir::Type outputETy =
480  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
481  if (auto quantType =
482  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
483  outputETy = quantType.getStorageType();
484  }
485  if (inputETy != outputETy)
486  return emitOpError("input/output element types are incompatible.");
487 
488  auto maxValAttr = getMaxValAttr();
489  auto minValAttr = getMinValAttr();
490 
491  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
492 
493  if (inputETy.isInteger(dataTypeBitWidth)) {
494  // if input datatype is integer, check that the min_val/max_val attributes
495  // are integer attributes, and that their type is the same as the input's
496  // datatype
497  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
498  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
499  if (!intMaxValAttr || !intMinValAttr ||
500  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
501  (intMaxValAttr.getType() != inputETy))
502  return emitOpError("min/max attributes types are incompatible with "
503  "input/output element types.");
504  } else {
505  // otherwise, input datatype is float, check that the min_val/max_val
506  // attributes share the same type and that their type is the same as the
507  // input's datatype
508  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
509  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
510  if (!floatMaxValAttr || !floatMinValAttr ||
511  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
512  (floatMaxValAttr.getType() != inputETy))
513  return emitOpError("min/max attributes types are incompatible with "
514  "input/output element types.");
515  }
516 
517  return success();
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // TOSA Operator Quantization Builders.
522 //===----------------------------------------------------------------------===//
523 
524 /// This builder is called on all convolution operators except TransposeConv,
525 /// which has specialized output shape semantics. The builder also defines the
526 /// bitwidth of the output given the bit width of the input & weight content.
527 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
528  Type outputType, Value input, Value weight,
529  Value bias, DenseI64ArrayAttr pad,
530  DenseI64ArrayAttr stride,
531  DenseI64ArrayAttr dilation,
532  TypeAttr accType) {
533  auto zps = createZPsAsConst(builder, input, weight);
534  result.addOperands({input, weight, bias, zps.first, zps.second});
535  result.addAttribute("pad", pad);
536  result.addAttribute("stride", stride);
537  result.addAttribute("dilation", dilation);
538  result.addAttribute("acc_type", accType);
539  Type finalOutputType = outputType;
540  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
541  if (quantAttr) {
542  finalOutputType =
543  buildConvOpResultTypeInfo(builder, outputType, input, weight);
544  }
545  result.addTypes(finalOutputType);
546 }
547 
548 /// Handles tosa.transpose_conv2d which has outpad and output shape
549 /// attributes.
550 static void buildTransConvOpWithQuantInfo(
551  OpBuilder &builder, OperationState &result, Type outputType, Value input,
552  Value weight, Value bias, DenseI64ArrayAttr outpad,
553  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
554  auto zps = createZPsAsConst(builder, input, weight);
555  result.addOperands({input, weight, bias, zps.first, zps.second});
556  result.addAttribute("out_pad", outpad);
557  result.addAttribute("stride", stride);
558  result.addAttribute("out_shape", outputShape);
559  result.addAttribute("acc_type", accType);
560  Type finalOutputType = outputType;
561  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
562  if (quantAttr) {
563  finalOutputType =
564  buildConvOpResultTypeInfo(builder, outputType, input, weight);
565  }
566  result.addTypes(finalOutputType);
567 }
568 
569 /// The tosa.fully_connected op has its own builder as it does not have
570 /// strides/dilation/padding.
571 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
572  Type outputType, Value input, Value weight,
573  Value bias) {
574 
575  result.addOperands({input, weight, bias});
576  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
577  if (quantAttr) {
578  result.addAttribute("quantization_info", quantAttr);
579  result.addTypes(
580  buildConvOpResultTypeInfo(builder, outputType, input, weight));
581  } else {
582  result.addTypes(outputType);
583  }
584 }
585 
586 /// The tosa.matmul op is also intended to be generated where a
587 /// fully_connected op must be constructed but the weight is not a constant.
588 /// In this case, the fully_connected op must be expressed using matmul.
589 /// TODO: Add link to the legalization document explaining this.
590 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
591  OperationState &result, Type outputType,
592  Value a, Value b) {
593  result.addOperands({a, b});
594  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
595 
596  if (quantAttr) {
597  result.addAttribute("a_zp", builder.getI32IntegerAttr(
598  static_cast<int32_t>(quantAttr.getAZp())));
599  result.addAttribute("b_zp", builder.getI32IntegerAttr(
600  static_cast<int32_t>(quantAttr.getBZp())));
601 
602  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
603  assert(inputType && "Input must be a shaped tensor type!");
604 
605  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
606  inputType.getElementType());
607  assert(inputQType && "Tensor must have quantized datatype!");
608 
609  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
610 
611  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
612  assert(outputShapedType && "Output must be a shaped type");
613 
614  IntegerType accElementType;
615  if (inputBits == 16)
616  accElementType = builder.getIntegerType(48);
617  else
618  accElementType = builder.getI32Type();
619  auto accType = outputShapedType.clone(accElementType);
620  result.addTypes(accType);
621  } else {
622  result.addTypes(outputType);
623  }
624 }
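// Summary of the accumulator selection above (illustrative shapes assumed):
// for quantized operands the result element type is widened to an accumulator
// type, e.g. two i8-storage quantized tensors produce an i32 result, while
// i16-storage inputs produce an i48 result.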
625 
626 /// Both the tosa.avg_pool2d and unary ops use the same
627 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
628 /// as it has additional parameters not part of the unary ops.
629 static void
630 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
631  Type outputType, Value input,
632  DenseArrayAttr kernel, DenseArrayAttr stride,
633  DenseArrayAttr pad, TypeAttr accType) {
634  result.addOperands(input);
635  result.addAttribute("kernel", kernel);
636  result.addAttribute("stride", stride);
637  result.addAttribute("pad", pad);
638  result.addAttribute("acc_type", accType);
639  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
640  if (quantAttr) {
641  result.addAttribute("input_zp",
642  builder.getI32IntegerAttr(
643  static_cast<int32_t>(quantAttr.getInputZp())));
644  result.addAttribute("output_zp",
645  builder.getI32IntegerAttr(
646  static_cast<int32_t>(quantAttr.getOutputZp())));
647  }
648  result.types.push_back(outputType);
649 }
650 
651 /// This builder is called on single-parameter unary operators that have a
652 /// scale relationship between their input and output, expressed by the
653 /// UnaryOpQuantizationAttr.
654 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
655  OperationState &result, Type outputType,
656  Value input) {
657  result.addOperands(input);
658  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
659  if (quantAttr) {
660  // note: negateOp has attributes input1_zp and output_zp
661  result.addAttribute("input1_zp",
662  builder.getI32IntegerAttr(
663  static_cast<int32_t>(quantAttr.getInputZp())));
664  result.addAttribute("output_zp",
665  builder.getI32IntegerAttr(
666  static_cast<int32_t>(quantAttr.getOutputZp())));
667  }
668  result.types.push_back(outputType);
669 }
670 
671 /// This builder is called on the TOSA pad operator that needs to create its
672 /// own OptionalAttr quantization_attr parameter to scale the padding values
673 /// correctly. The absence of pad_const is interpreted as zero padding.
674 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
675  Type outputType, Value input,
676  Value paddings) {
677  result.addOperands({input, paddings});
678  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
679  if (quantAttr) {
680  result.addAttribute("input_zp",
681  builder.getI32IntegerAttr(
682  static_cast<int32_t>(quantAttr.getInputZp())));
683  }
684  result.types.push_back(outputType);
685 }
686 
687 /// This builder is called on the TOSA pad operator when an explicit pad_const
688 /// value is passed in. It also optionally constructs quantization_attr.
689 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
690  OperationState &result,
691  Type outputType, Value input,
692  Value paddings,
693  Value padConst) {
694  result.addOperands({input, paddings, padConst});
695  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
696  if (quantAttr) {
697  result.addAttribute("input_zp",
698  builder.getI32IntegerAttr(
699  static_cast<int32_t>(quantAttr.getInputZp())));
700  }
701  result.types.push_back(outputType);
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // TOSA Operator Return Type Inference.
706 //===----------------------------------------------------------------------===//
707 
708 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
709  SmallVector<int64_t> &outShape) {
710  int64_t outRank = 0;
711  for (int i = 0, e = operands.size(); i != e; ++i) {
712  auto shape = operands.getShape(i);
713  if (!shape.hasRank()) {
714  // TODO(jennik): Update function to have better case handling for
715  // invalid operands and for ranked tensors.
716  return failure();
717  }
718  outRank = std::max<int64_t>(outRank, shape.getRank());
719  }
720 
721  outShape.resize(outRank, 1);
722 
723  for (int i = 0, e = operands.size(); i != e; ++i) {
724  auto shape = operands.getShape(i);
725  auto rankDiff = outShape.size() - shape.getRank();
726 
727  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
728  auto dim1 = outShape[i + rankDiff];
729  auto dim2 = shape.getDimSize(i);
730  auto resolvedDim = dim1;
731 
732  if (dim1 == 1) {
733  resolvedDim = dim2;
734  } else if (dim2 == 1) {
735  resolvedDim = dim1;
736  } else if (dim1 != dim2) {
737  return failure();
738  }
739  outShape[i + rankDiff] = resolvedDim;
740  }
741  }
742 
743  return success();
744 }
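// Worked example for resolveBroadcastShape (illustrative, hypothetical shapes):
// broadcasting [2, 1, 3] and [5, 3] right-aligns the lower-ranked shape to
// [1, 5, 3]; each dimension pair then resolves to the non-1 size, giving
// outShape = [2, 5, 3]. A pair such as (2, 5), with neither dim equal to 1,
// makes the resolution fail.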
745 
746 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
747  MLIRContext *context, ::std::optional<Location> location,
748  ArgMaxOp::Adaptor adaptor,
749  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
750  ShapeAdaptor inputShape(adaptor.getInput().getType());
751  IntegerAttr axis = adaptor.getProperties().axis;
752  int32_t axisVal = axis.getValue().getSExtValue();
753 
754  if (!inputShape.hasRank()) {
755  inferredReturnShapes.push_back(ShapedTypeComponents());
756  return success();
757  }
758 
759  SmallVector<int64_t> outShape;
760  outShape.reserve(inputShape.getRank() - 1);
761  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
762  if (i == axisVal)
763  continue;
764  outShape.push_back(inputShape.getDimSize(i));
765  }
766 
767  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
768  return success();
769 }
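// ArgMaxOp shape inference, illustrative example (hypothetical shapes): an
// input of tensor<2x3x4xf32> with axis = 1 infers a result shape of 2x4; the
// axis dimension is simply dropped from the input shape.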
770 
771 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
772  MLIRContext *context, ::std::optional<Location> location,
773  RFFT2dOp::Adaptor adaptor,
774  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
775  ShapeAdaptor inputShape(adaptor.getInput().getType());
776 
777  if (!inputShape.hasRank())
778  return failure();
779 
780  llvm::SmallVector<int64_t> outputShape;
781  outputShape.resize(3, ShapedType::kDynamic);
782  outputShape[0] = inputShape.getDimSize(0);
783  outputShape[1] = inputShape.getDimSize(1);
784  int64_t inWidth = inputShape.getDimSize(2);
785 
786  // Note that we can support this calculation symbolically
787  // in the future, e.g. [x, y, z] -> [x, y, z / 2 + 1]
788  if (inWidth != ShapedType::kDynamic)
789  outputShape[2] = inWidth / 2 + 1;
790 
791  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
792  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
793 
794  return success();
795 }
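// RFFT2dOp shape inference, illustrative example (hypothetical shapes): an
// input of tensor<1x8x16xf32> infers two results (real and imaginary parts)
// of shape 1x8x9, since the innermost dimension becomes 16 / 2 + 1 = 9.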
796 
797 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
798  MLIRContext *context, ::std::optional<Location> location,
799  FFT2dOp::Adaptor adaptor,
800  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
801  inferredReturnShapes.push_back(
802  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
803  inferredReturnShapes.push_back(
804  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
805  return success();
806 }
807 
808 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
809  MLIRContext *context, ::std::optional<Location> location,
810  ConcatOp::Adaptor adaptor,
811  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
812  // Infer all dimension sizes by reducing based on inputs.
813  const Properties &prop = adaptor.getProperties();
814  int32_t axis = prop.axis.getValue().getSExtValue();
815  llvm::SmallVector<int64_t> outputShape;
816  bool hasRankedInput = false;
817  for (auto operand : adaptor.getOperands()) {
818  ShapeAdaptor operandShape(operand.getType());
819  if (!operandShape.hasRank())
820  continue;
821 
822  // Copy the Operand's rank.
823  if (!hasRankedInput)
824  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
825 
826  // Copy shapes until the dim is non-dynamic.
827  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
828  if (i == axis || operandShape.isDynamicDim(i))
829  continue;
830  if (outputShape[i] == ShapedType::kDynamic)
831  outputShape[i] = operandShape.getDimSize(i);
832  if (outputShape[i] != operandShape.getDimSize(i))
833  return emitOptionalError(location,
834  "Cannot concat tensors with different sizes"
835  " on the non-axis dimension ",
836  i);
837  }
838 
839  hasRankedInput = true;
840  }
841  Type inputType =
842  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
843  if (!hasRankedInput) {
844  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
845  return success();
846  }
847 
848  // Determine the dimension size along the concatenation axis.
849  int64_t concatDimSize = 0;
850  for (auto operand : adaptor.getOperands()) {
851  ShapeAdaptor operandShape(operand.getType());
852 
853  // We need to know the length of the concatenation axis of all inputs to
854  // determine the dimension size of the output shape.
855  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
856  concatDimSize = ShapedType::kDynamic;
857  break;
858  }
859 
860  concatDimSize += operandShape.getDimSize(axis);
861  }
862 
863  outputShape[axis] = concatDimSize;
864 
865  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
866  return success();
867 }
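// ConcatOp shape inference, illustrative example (hypothetical shapes):
// concatenating tensor<2x3xf32> and tensor<2x?xf32> on axis = 1 infers 2x?,
// because one input is dynamic on the concat axis; with tensor<2x5xf32>
// instead, the result would be 2x8.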
868 
869 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
870  MLIRContext *context, ::std::optional<Location> location,
871  ValueShapeRange operands, DictionaryAttr attributes,
872  OpaqueProperties properties, RegionRange regions,
873  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
874  auto elementType = IntegerType::get(context, /*width=*/1);
875 
876  llvm::SmallVector<int64_t> outShape;
877  if (resolveBroadcastShape(operands, outShape).failed()) {
878  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
879  return success();
880  }
881 
882  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
883  return success();
884 }
885 
886 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
887  if (l.size() != r.size() || l.size() != 1)
888  return false;
889  return succeeded(verifyCompatibleShape(l[0], r[0]));
890 }
891 
892 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
893  MLIRContext *context, ::std::optional<Location> location,
894  FullyConnectedOp::Adaptor adaptor,
895  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
896  ShapeAdaptor inputShape(adaptor.getInput().getType());
897  ShapeAdaptor weightShape(adaptor.getWeight().getType());
898  ShapeAdaptor biasShape(adaptor.getBias().getType());
899 
900  // All shapes are dynamic.
901  SmallVector<int64_t> outShape;
902  outShape.resize(2, ShapedType::kDynamic);
903 
904  if (inputShape.hasRank()) {
905  outShape[0] = inputShape.getDimSize(0);
906  }
907 
908  if (weightShape.hasRank()) {
909  outShape[1] = weightShape.getDimSize(0);
910  }
911 
912  if (biasShape.hasRank()) {
913  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
914  : outShape[1];
915  }
916 
917  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
918  return success();
919 }
920 
921 LogicalResult FullyConnectedOp::verify() {
922  // All TOSA conv ops have an input() and weight().
923  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
924 
925  RankedTensorType weightType =
926  llvm::dyn_cast<RankedTensorType>(getWeight().getType());
927 
928  // Must be ranked tensor types
929  if (!inputType) {
930  emitOpError("expect a ranked tensor for input, got ") << getInput();
931  return failure();
932  }
933  if (!weightType) {
934  emitOpError("expect a ranked tensor for weight, got ") << getWeight();
935  return failure();
936  }
937 
938  auto inputEType = inputType.getElementType();
939  auto weightEType = weightType.getElementType();
940 
941  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
942  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
943 
944  // Either both must be quantized or both unquantized.
945  if (inputIsQuant != weightIsQuant) {
946  emitOpError(
947  "expect both input and weight to be float or not together, got ")
948  << inputEType << " and " << weightEType;
949  return failure();
950  }
951 
952  // A quantized type must have constructed the quantization attr, and
953  // unquantized types should not have a quantization attr.
954  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
955  emitOpError("input zero point is required for quantized type, and not "
956  "allowed for float type");
957  return failure();
958  }
959  return success();
960 }
961 
962 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
963  MLIRContext *context, ::std::optional<Location> location,
964  MatMulOp::Adaptor adaptor,
965  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
966  ShapeAdaptor lhsShape(adaptor.getA().getType());
967  ShapeAdaptor rhsShape(adaptor.getB().getType());
968 
969  // All shapes are dynamic.
970  SmallVector<int64_t> outShape;
971  outShape.resize(3, ShapedType::kDynamic);
972 
973  if (lhsShape.hasRank()) {
974  outShape[0] = lhsShape.getDimSize(0);
975  outShape[1] = lhsShape.getDimSize(1);
976  }
977 
978  if (rhsShape.hasRank()) {
979  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
980  : outShape[0];
981  outShape[2] = rhsShape.getDimSize(2);
982  }
983 
984  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
985  return success();
986 }
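// MatMulOp shape inference, illustrative example (hypothetical shapes):
// multiplying tensor<4x8x16xi8> by tensor<4x16x32xi8> infers a result shape
// of 4x8x32; the batch and outer dims come from the operands, and unranked
// operands fall back to dynamic dims.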
987 
988 LogicalResult tosa::PadOp::inferReturnTypeComponents(
989  MLIRContext *context, ::std::optional<Location> location,
990  PadOp::Adaptor adaptor,
991  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
992  ShapeAdaptor inputShape(adaptor.getInput1().getType());
993  auto paddingRank =
994  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
995  SmallVector<int64_t> outputShape;
996 
997  // If the input rank is unknown, we can infer the output rank using the
998  // padding shape's rank divided by 2.
999  if (!inputShape.hasRank()) {
1000  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
1001  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1002  return success();
1003  }
1004 
1005  SmallVector<int64_t> paddingValues;
1006  // If the paddings value is not a constant, all dimensions must be dynamic.
1007  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
1008  paddingValues)) {
1009  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1010  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1011  return success();
1012  }
1013 
1014  outputShape.reserve(inputShape.getRank());
1015  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1016  if (inputShape.isDynamicDim(i)) {
1017  outputShape.push_back(ShapedType::kDynamic);
1018  continue;
1019  }
1020  auto padFront = paddingValues[i * 2];
1021  auto padBack = paddingValues[i * 2 + 1];
1022  if (padFront < 0 || padBack < 0) {
1023  // if either padding for dim i is -1, output dim is unknown
1024  outputShape.push_back(ShapedType::kDynamic);
1025  continue;
1026  }
1027 
1028  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
1029  }
1030 
1031  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1032  return success();
1033 }
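// PadOp shape inference, illustrative example (hypothetical values): padding
// tensor<3x4xf32> with a constant padding shape of [1, 1, 2, 2] (front/back
// per dimension) infers a result shape of 5x8; if the padding values are not
// constant, every result dimension is reported as dynamic instead.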
1034 
1035 LogicalResult tosa::PadOp::verify() {
1036  RankedTensorType inputType = getInput1().getType();
1037  RankedTensorType outputType = getOutput().getType();
1038  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
1039 
1040  if (inputType.getRank() != outputType.getRank())
1041  return emitOpError() << "expect same input and output tensor rank.";
1042 
1043  if (paddingRank != inputType.getRank() * 2)
1044  return emitOpError() << "expected padding tensor dim 0 to have size "
1045  << inputType.getRank() * 2
1046  << " (2*rank(shape1)) but got size " << paddingRank;
1047 
1048  return success();
1049 }
1050 
1051 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
1052  return to_vector(llvm::map_range(shape, [](int64_t dim) {
1053  return dim == -1 ? ShapedType::kDynamic : dim;
1054  }));
1055 }
1056 
1057 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1058  MLIRContext *context, ::std::optional<Location> location,
1059  SliceOp::Adaptor adaptor,
1060  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1061 
1062  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1063  SmallVector<int64_t> start;
1064  SmallVector<int64_t> size;
1065 
1066  if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
1067  !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
1068  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
1069  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1070  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1071  return success();
1072  }
1073 
1074  // if size[i] is -1, all remaining elements in dimension i are included
1075  // in the slice, similar to TF.
1076  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1077  // initialize outputShape to all unknown
1078  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
1079  if (inputShape.hasRank()) {
1080  for (size_t i = 0; i < size.size(); i++) {
1081  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
1082  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
1083  start[i] < inputShape.getDimSize(i))) {
1084  // size[i] is not 0 and not < -1, and start[i] is in valid range
1085  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
1086  // input shape has unknown dim[i] - only valid if size[i] > 0
1087  if (size[i] > 0) {
1088  outputShape[i] = size[i];
1089  }
1090  } else {
1091  // input shape has known dim[i]
1092  if (size[i] == -1) {
1093  outputShape[i] = inputShape.getDimSize(i) - start[i];
1094  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1095  // start[i] + size[i] is within bound of input shape's dim[i]
1096  outputShape[i] = size[i];
1097  }
1098  }
1099  }
1100  }
1101  } else {
1102  outputShape = convertToMlirShape(size);
1103  }
1104  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1105  return success();
1106 }
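// SliceOp shape inference, illustrative example (hypothetical values): slicing
// tensor<4x10xf32> with start = [1, 2] and size = [2, -1] infers a result
// shape of 2x8, since a size of -1 keeps everything from the start index to
// the end of that dimension (10 - 2 = 8).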
1107 
1108 LogicalResult tosa::SliceOp::verify() {
1109  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1110  if (!inputType)
1111  return success();
1112 
1113  auto startShapeRank =
1114  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1115  if (inputType.getRank() != startShapeRank)
1116  return emitOpError(
1117  "length of start attribute is not equal rank of input shape");
1118 
1119  auto sizeShapeRank =
1120  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1121  if (inputType.getRank() != sizeShapeRank)
1122  return emitOpError(
1123  "length of size attribute is not equal rank of input shape");
1124 
1125  return success();
1126 }
1127 
1128 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1129  MLIRContext *context, ::std::optional<Location> location,
1130  ValueShapeRange operands, DictionaryAttr attributes,
1131  OpaqueProperties properties, RegionRange regions,
1132  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1133  // mul op's output shape only depends on input1 and input2, not on shift
1134  ValueShapeRange twoInputs = operands.drop_back();
1135  llvm::SmallVector<int64_t> outShape;
1136  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1137  inferredReturnShapes.push_back(ShapedTypeComponents());
1138  } else {
1139  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1140  }
1141  return success();
1142 }
1143 
1144 LogicalResult tosa::MulOp::verify() {
1145  auto resElemType = getElementTypeOrSelf(getOutput());
1146 
1147  // Verify if the element type among operands and result match tosa
1148  // specification.
1149  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1150  IntegerType lhsIntType =
1151  cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1152  IntegerType rhsIntType =
1153  cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1154  if (lhsIntType != rhsIntType)
1155  return emitOpError("requires the same element type for all operands");
1156 
1157  // Though the spec requires the element type of the result to be i32, a
1158  // more relaxed rule is provided at the dialect level for easier
1159  // interoperation with other dialects.
1160  if (lhsIntType.getWidth() > resIntType.getWidth())
1161  return emitOpError("invalid data type size for operands or result");
1162 
1163  } else {
1164  // For the other supported types, the spec requires the same element
1165  // type for all operands (excluding the `shift` operand) and results.
1166  for (int i = 0; i < 2; ++i) {
1167  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1168  return emitOpError(
1169  "requires the same element type for all operands and results");
1170  }
1171 
1172  // verify shift has value 0 for non-integer types
1173  ElementsAttr shift_elem;
1174  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1175  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1176  if (shift != 0) {
1177  return emitOpError() << "require shift to be 0 for float type";
1178  }
1179  }
1180  }
1181 
1182  // Verify the op has the same rank for all main operands (excluding extra
1183  // operands such as the shift of mul op, which is the only difference from
1184  // the built-in `SameOperandsAndResultRank` trait) and result types, if known.
1185 
1186  // delegate function that returns true if type is a shaped type with known
1187  // rank
1188  auto hasRank = [](const Type type) {
1189  if (auto shaped_type = dyn_cast<ShapedType>(type))
1190  return shaped_type.hasRank();
1191 
1192  return false;
1193  };
1194 
1195  auto rankedOperandTypes =
1196  llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1197 
1198  auto rankedResultTypes =
1199  llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1200 
1201  // If all operands and results are unranked, then no further verification.
1202  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1203  return success();
1204 
1205  // delegate function that returns rank of shaped type with known rank
1206  auto getRank = [](const Type type) {
1207  return cast<ShapedType>(type).getRank();
1208  };
1209 
1210  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1211  : getRank(*rankedResultTypes.begin());
1212 
1213  for (size_t i = 0; i < 2; ++i) {
1214  if (rank != getRank(rankedOperandTypes[i])) {
1215  return emitOpError("operands don't have matching ranks");
1216  }
1217  }
1218 
1219  for (const auto type : rankedResultTypes) {
1220  if (rank != getRank(type)) {
1221  return emitOpError("result type has different rank than operands");
1222  }
1223  }
1224 
1225  // check for broadcast compatible shapes in first two operands (ignoring
1226  // shift)
1227 
1228  // delegate function that returns shape of shaped type
1229  auto getShape = [](const Type type) {
1230  return mlir::cast<ShapedType>(type).getShape();
1231  };
1232  SmallVector<int64_t> resultShape;
1233  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1234  getShape(rankedOperandTypes[1]),
1235  resultShape)) {
1236  return emitOpError("operands don't have broadcast-compatible shapes");
1237  }
1238 
1239  return success();
1240 }
1241 
1242 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1243  MLIRContext *context, ::std::optional<Location> location,
1244  TableOp::Adaptor adaptor,
1245  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1246  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1247 
1248  if (!inputShape.hasRank()) {
1249  inferredReturnShapes.push_back(ShapedTypeComponents());
1250  return success();
1251  }
1252 
1253  inferredReturnShapes.resize(1);
1254  inputShape.getDims(inferredReturnShapes[0]);
1255  return success();
1256 }
1257 
1258 LogicalResult tosa::TableOp::verify() {
1259  TensorType inputType = getInput1().getType();
1260  TensorType outputType = getOutput().getType();
1261 
1262  if (inputType.hasRank() && outputType.hasRank() &&
1263  inputType.getRank() != outputType.getRank())
1264  return emitOpError()
1265  << "expected input tensor rank to equal result tensor rank";
1266 
1267  auto inputDims = inputType.getShape();
1268  auto outputDims = outputType.getShape();
1269  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1270  int64_t dim = it.index();
1271  auto [inputDim, outputDim] = it.value();
1272  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1273  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1274  << " doesn't match dim(input, " << dim
1275  << ") = " << inputDim;
1276  }
1277  }
1278  return success();
1279 }
1280 
1281 LogicalResult
1282 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1283  // Multiples must be constants.
1284  DenseIntElementsAttr multiplesAttr;
1285  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1286  return failure();
1287  multiples = llvm::to_vector(
1288  llvm::map_range(multiplesAttr.getValues<APInt>(),
1289  [](const APInt &val) { return val.getSExtValue(); }));
1290  return success();
1291 }
1292 
1293 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1294  MLIRContext *context, ::std::optional<Location> location,
1295  TileOp::Adaptor adaptor,
1296  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1297  DenseIntElementsAttr multiplesAttr;
1298  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1299  return failure();
1300 
1301  SmallVector<int64_t> multiples = llvm::to_vector(
1302  llvm::map_range(multiplesAttr.getValues<APInt>(),
1303  [](const APInt &val) { return val.getSExtValue(); }));
1304 
1305  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1306  SmallVector<int64_t> outputShape;
1307  if (!inputShape.hasRank()) {
1308  outputShape.resize(multiples.size(), ShapedType::kDynamic);
1309  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1310  return success();
1311  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1312  return failure();
1313 
1314  // Any non-dynamic dimension can be multiplied to a known size.
1315  outputShape.reserve(multiples.size());
1316  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1317  int64_t dim = inputShape.getDimSize(i);
1318  if (dim != ShapedType::kDynamic)
1319  dim *= multiples[i];
1320  outputShape.push_back(dim);
1321  }
1322 
1323  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1324  return success();
1325 }
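// TileOp shape inference, illustrative example (hypothetical values): tiling
// tensor<2x?x3xf32> with multiples = [3, 2, 1] infers a result shape of
// 6x?x3; static dims are multiplied, dynamic dims stay dynamic.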
1326 
1327 LogicalResult tosa::TileOp::verify() {
1328  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1329  ShapedType outputType = llvm::cast<ShapedType>(getType());
1330 
1331  shapeType multiplesType =
1332  llvm::cast<tosa::shapeType>(getMultiples().getType());
1333 
1334  auto multiplesRank = multiplesType.getRank();
1335 
1336  if (inputType.hasRank()) {
1337  if (inputType.getRank() != multiplesRank)
1338  return emitOpError("expect 'multiples' to have rank ")
1339  << inputType.getRank() << " but got " << multiplesRank << ".";
1340  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1341  return emitOpError("expect same input and output tensor rank.");
1342  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1343  return emitOpError("expect 'multiples' array to have length ")
1344  << outputType.getRank() << " but got " << multiplesRank << ".";
1345 
1346  SmallVector<int64_t> multiples;
1347  if (getConstantMultiples(multiples).succeeded() &&
1348  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1349  return emitOpError(
1350  "expect element of 'multiples' to be positive integer or -1.");
1351 
1352  return success();
1353 }
1354 
1355 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1356  if (l.size() != r.size() || l.size() != 1)
1357  return false;
1358  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
1359 }
1360 
1361 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1362  MLIRContext *context, ::std::optional<Location> location,
1363  ReshapeOp::Adaptor adaptor,
1364  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1365  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1366  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1367  llvm::SmallVector<int64_t> newShapeValue;
1368  if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1369  newShapeValue)) {
1370  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1371  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1372  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1373  return success();
1374  } else {
1375  newShapeValue = convertToMlirShape(newShapeValue);
1376  }
1377 
1378  // We cannot infer from the total number of elements so we must take the
1379  // shape attribute as exact.
1380  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1381  inferredReturnShapes.push_back(
1382  ShapedTypeComponents(newShapeValue, inputType));
1383  return success();
1384  }
1385 
1386  // Determine the number of elements covered by the slice of all static
1387  // dimensions. This allows us to infer the length of the remaining dynamic
1388  // dimension.
1389  int64_t numElements = inputShape.getNumElements();
1390  int64_t staticMul = 1;
1391  for (auto val : newShapeValue) {
1392  if (!ShapedType::isDynamic(val)) {
1393  staticMul *= val;
1394  }
1395  }
1396 
1397  // Determine the length of the dynamic dimension.
1398  for (auto &val : newShapeValue) {
1399  if (ShapedType::isDynamic(val))
1400  val = numElements / staticMul;
1401  }
1402 
1403  inferredReturnShapes.push_back(
1404  ShapedTypeComponents(newShapeValue, inputType));
1405  return success();
1406 }
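// ReshapeOp shape inference, illustrative example (hypothetical values):
// reshaping tensor<2x3xf32> with a constant shape of [3, -1] infers 3x2,
// because the single dynamic (-1) dimension is derived from the 6 total
// elements divided by the static product 3.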
1407 
1408 llvm::LogicalResult tosa::ReshapeOp::verify() {
1409  TensorType inputType = getInput1().getType();
1410  RankedTensorType outputType = getType();
1411 
1412  SmallVector<int64_t> shapeValues;
1413  if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1414  // skip following checks if shape is not constant
1415  return mlir::success();
1416  }
1417 
1418  if ((int64_t)shapeValues.size() != outputType.getRank())
1419  return emitOpError() << "new shape does not match result rank";
1420 
1421  for (auto [newShapeDim, outputShapeDim] :
1422  zip(shapeValues, outputType.getShape())) {
1423  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1424  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
1425  return emitOpError() << "new shape is inconsistent with result shape";
1426 
1427  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1428  return emitOpError() << "new shape has invalid tensor dimension size "
1429  << newShapeDim;
1430  }
1431 
1432  if (inputType.hasStaticShape()) {
1433  int64_t inputElementsNum = inputType.getNumElements();
1434  if (outputType.hasStaticShape()) {
1435  int64_t outputElementsNum = outputType.getNumElements();
1436  if (inputElementsNum != outputElementsNum) {
1437  return emitOpError() << "cannot reshape " << inputElementsNum
1438  << " elements into " << outputElementsNum;
1439  }
1440  }
1441 
1442  int64_t newShapeElementsNum = std::accumulate(
1443  shapeValues.begin(), shapeValues.end(), 1LL,
1444  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1445  bool isStaticNewShape =
1446  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
1447  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1448  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1449  return emitOpError() << "cannot reshape " << inputElementsNum
1450  << " elements into " << newShapeElementsNum;
1451  }
1452  }
1453 
1454  int missingDims = llvm::count(shapeValues, -1);
1455  if (missingDims > 1)
1456  return emitOpError() << "expected at most one target dimension to be -1";
1457 
1458  return mlir::success();
1459 }
1460 
1461 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1462  // Perms must be constants.
1463  DenseIntElementsAttr permsAttr;
1464  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1465  return failure();
1466 
1467  perms.clear();
1468  for (auto v : permsAttr.getValues<APInt>())
1469  perms.push_back(v.getSExtValue());
1470 
1471  return success();
1472 }
1473 
1474 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1475  MLIRContext *context, ::std::optional<Location> location,
1476  TransposeOp::Adaptor adaptor,
1477  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1478  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1479  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1480 
1481  // We cannot infer anything from a rank-0 "permutation" tensor.
1482  if (permsShape.hasRank() && permsShape.getRank() == 0)
1483  return failure();
1484 
1485  // If the input rank or the permutation length is unknown, the output rank
1486  // is unknown.
1487  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1488  permsShape.isDynamicDim(0)) {
1489  inferredReturnShapes.push_back(ShapedTypeComponents());
1490  return success();
1491  }
1492 
1493  // This would imply that the number of permutation entries does not match
1494  // the rank of the input, which is illegal.
1495  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1496  return failure();
1497  }
1498 
1499  SmallVector<int64_t> outputShape;
1500  // Rank-0 means no permutations matter.
1501  if (inputShape.getRank() == 0) {
1502  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1503  return success();
1504  }
1505 
1506  // Check whether the input dimensions are all the same.
1507  bool allTheSame = true;
1508  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1509  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1510  allTheSame = false;
1511  break;
1512  }
1513  }
1514 
1515  // If all of the input dimensions are the same, we don't care about the
1516  // permutation.
1517  if (allTheSame) {
1518  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1519  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1520  return success();
1521  }
1522 
1523  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1524  // If the permutations are constant, we can directly determine the output
1525  // shape.
1526  DenseIntElementsAttr attr;
1527  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1528  attr.getType().getRank() == 1) {
1529  ShapeAdaptor permShape = attr;
1530  // Constant permutation must be the same length as the input rank.
1531  if (inputShape.getRank() != permShape.getRank())
1532  return emitOptionalError(location,
1533  "constant permutation must be the same length"
1534  " as the input rank");
1535 
1536  // Constant permutation values must be within the input rank.
1537  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1538  if (inputShape.getRank() <= permShape.getDimSize(i))
1539  return failure();
1540  }
1541 
1542  outputShape.reserve(inputShape.getRank());
1543  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1544  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1545  }
1546  }
1547 
1548  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1549  return success();
1550 }
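 // Illustrative note (editorial): for an input of shape {2, 3, 4} with a
 // constant permutation [2, 0, 1], the loop above sets
 // outputShape[i] = inputShape[perms[i]], yielding an inferred shape of
 // {4, 2, 3}.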
1551 
1552 LogicalResult tosa::TransposeOp::verify() {
1553  TensorType inputType = getInput1().getType();
1554  TensorType permType = getPerms().getType();
1555  TensorType outputType = getOutput().getType();
1556 
1557  if (permType.hasRank() && permType.getRank() != 1)
1558  return emitOpError()
1559  << "expected permutation tensor to be rank 1 but got rank "
1560  << permType.getRank();
1561  if (inputType.hasRank() && permType.hasRank())
1562  if (!permType.isDynamicDim(0) &&
1563  permType.getDimSize(0) != inputType.getRank())
1564  return emitOpError() << "expected permutation tensor dim 0 to have size "
1565  << inputType.getRank()
1566  << " (input rank) but got size "
1567  << permType.getDimSize(0);
1568  if (inputType.hasRank() && outputType.hasRank() &&
1569  inputType.getRank() != outputType.getRank())
1570  return emitOpError()
1571  << "expected input tensor rank to equal result tensor rank";
1572  if (outputType.hasRank() && permType.hasRank())
1573  if (!permType.isDynamicDim(0) &&
1574  permType.getDimSize(0) != outputType.getRank())
1575  return emitOpError() << "expected permutation tensor dim 0 to have size "
1576  << outputType.getRank()
1577  << " (output rank) but got size "
1578  << permType.getDimSize(0);
1579 
1580  SmallVector<int32_t> constantPerms;
1581  if (succeeded(getConstantPerms(constantPerms))) {
1582  // Assert that the permutation tensor has a rank, which means that the
1583  // rank has been verified above.
1584  assert(permType.hasRank() &&
1585  "Unexpectedly found permutation tensor without rank");
1586  if (!llvm::all_of(constantPerms,
1587  [&constantPerms](int32_t s) {
1588  return s >= 0 &&
1589  static_cast<size_t>(s) < constantPerms.size();
1590  }) ||
1591  !isPermutationVector(llvm::to_vector(llvm::map_range(
1592  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1593  return emitOpError() << "expected valid permutation tensor";
1594 
1595  // Verify that the types of the input and output tensors are properly
1596  // permuted.
1597  if (inputType.hasRank() && outputType.hasRank()) {
1598  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1599  inputType.getRank() == outputType.getRank());
1600 
1601  for (auto i = 0; i < outputType.getRank(); i++) {
1602  if (inputType.isDynamicDim(constantPerms[i]) ||
1603  outputType.isDynamicDim(i))
1604  continue;
1605 
1606  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1607  return emitOpError()
1608  << "expected output tensor dim " << i << " to match "
1609  << "input dim " << constantPerms[i] << " with value of "
1610  << inputType.getDimSize(constantPerms[i]);
1611  }
1612  }
1613  }
1614  return success();
1615 }
1616 
1617 LogicalResult TransposeOp::reifyResultShapes(
1618  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1619 
1620  SmallVector<int32_t> transposePerms;
1621  if (getConstantPerms(transposePerms).failed())
1622  return failure();
1623 
1624  Value input = getInput1();
1625  auto inputType = cast<TensorType>(input.getType());
1626 
1627  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1628  for (auto dim : transposePerms) {
1629  int32_t dimInInput = transposePerms[dim];
1630  if (inputType.isDynamicDim(dimInInput))
1631  returnedDims[dim] =
1632  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1633  .getResult();
1634  else
1635  returnedDims[dim] =
1636  builder.getIndexAttr(inputType.getDimSize(dimInInput));
1637  }
1638 
1639  reifiedReturnShapes.emplace_back(std::move(returnedDims));
1640  return success();
1641 }
1642 
1643 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1644  MLIRContext *context, ::std::optional<Location> location,
1645  GatherOp::Adaptor adaptor,
1646  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1647  llvm::SmallVector<int64_t> outputShape;
1648  outputShape.resize(3, ShapedType::kDynamic);
1649 
1650  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1651  if (valuesShape.hasRank()) {
1652  outputShape[0] = valuesShape.getDimSize(0);
1653  outputShape[2] = valuesShape.getDimSize(2);
1654  }
1655 
1656  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1657  if (indicesShape.hasRank()) {
1658  if (outputShape[0] == ShapedType::kDynamic)
1659  outputShape[0] = indicesShape.getDimSize(0);
1660  if (outputShape[1] == ShapedType::kDynamic)
1661  outputShape[1] = indicesShape.getDimSize(1);
1662  }
1663 
1664  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1665  return success();
1666 }
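 // Illustrative note (editorial): gather takes values of shape {N, K, C} and
 // indices of shape {N, W}; the inference above combines them into an output
 // of shape {N, W, C}, leaving any dimension that remains unknown as
 // ShapedType::kDynamic.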
1667 
1668 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1669  MLIRContext *context, ::std::optional<Location> location,
1670  ResizeOp::Adaptor adaptor,
1671  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1672  llvm::SmallVector<int64_t, 4> outputShape;
1673  outputShape.resize(4, ShapedType::kDynamic);
1674 
1675  ShapeAdaptor inputShape(adaptor.getInput().getType());
1676  if (!inputShape.hasRank())
1677  return failure();
1678 
1679  outputShape[0] = inputShape.getDimSize(0);
1680  outputShape[3] = inputShape.getDimSize(3);
1681  int64_t inputHeight = inputShape.getDimSize(1);
1682  int64_t inputWidth = inputShape.getDimSize(2);
1683 
1684  if ((inputHeight == ShapedType::kDynamic) ||
1685  (inputWidth == ShapedType::kDynamic))
1686  return failure();
1687 
1688  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1689  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1690  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1691 
1692  // Compute the output shape based on attributes: scale, offset, and border.
1693  outputShape[1] =
1694  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1695  scaleInt[1]) +
1696  1;
1697 
1698  outputShape[2] =
1699  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1700  scaleInt[3]) +
1701  1;
1702 
1703  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1704  return success();
1705 }
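 // Illustrative note (editorial): with input height 4, scale = {2, 1, ...},
 // offset = 0 and border = 0, the formula above gives
 // ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7 output rows for a 2x upscale.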
1706 
1707 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1708  MLIRContext *context, ::std::optional<Location> location,
1709  ScatterOp::Adaptor adaptor,
1710  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1711  llvm::SmallVector<int64_t> outputShape;
1712  outputShape.resize(3, ShapedType::kDynamic);
1713 
1714  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1715  if (valuesInShape.hasRank()) {
1716  outputShape[0] = valuesInShape.getDimSize(0);
1717  outputShape[1] = valuesInShape.getDimSize(1);
1718  outputShape[2] = valuesInShape.getDimSize(2);
1719  }
1720 
1721  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1722  if (indicesShape.hasRank()) {
1723  if (outputShape[0] == ShapedType::kDynamic)
1724  outputShape[0] = indicesShape.getDimSize(0);
1725  }
1726 
1727  ShapeAdaptor inputShape(adaptor.getInput().getType());
1728  if (inputShape.hasRank()) {
1729  if (outputShape[0] == ShapedType::kDynamic)
1730  outputShape[0] = inputShape.getDimSize(0);
1731  if (outputShape[2] == ShapedType::kDynamic)
1732  outputShape[2] = inputShape.getDimSize(2);
1733  }
1734 
1735  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1736  return success();
1737 }
1738 
1739 static LogicalResult ReduceInferReturnTypes(
1740  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1741  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1742  int64_t axisVal = axis.getValue().getSExtValue();
1743  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1744  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1745  return success();
1746  }
1747 
1748  SmallVector<int64_t> outputShape;
1749  operandShape.getDims(outputShape);
1750  outputShape[axisVal] = 1;
1751  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1752  return success();
1753 }
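 // Illustrative note (editorial): reducing a ranked input of shape {2, 3, 4}
 // along axis 1 produces an inferred shape of {2, 1, 4}; if the operand is
 // unranked or the axis is out of range, only the element type is propagated.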
1754 
1755 #define COMPATIBLE_RETURN_TYPES(OP) \
1756  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1757  if (l.size() != r.size() || l.size() != 1) \
1758  return false; \
1759  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1760  return false; \
1761  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1762  }
1763 
1764 #define REDUCE_SHAPE_INFER(OP) \
1765  LogicalResult OP::inferReturnTypeComponents( \
1766  MLIRContext *context, ::std::optional<Location> location, \
1767  OP::Adaptor adaptor, \
1768  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1769  Type inputType = \
1770  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1771  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1772  const Properties &prop = adaptor.getProperties(); \
1773  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1774  inferredReturnShapes); \
1775  } \
1776  COMPATIBLE_RETURN_TYPES(OP)
1777 
1778 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1779 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1780 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1781 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1782 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1783 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1784 #undef REDUCE_SHAPE_INFER
1785 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1786 #undef COMPATIBLE_RETURN_TYPES
1787 
1788 template <typename T>
1789 static LogicalResult verifyReduceOp(T op) {
1790  // All TOSA reduce Ops have input, output and axis.
1791  TensorType inputType = op.getInput().getType();
1792  TensorType outputType = op.getOutput().getType();
1793  int32_t reduceAxis = op.getAxis();
1794 
1795  if (reduceAxis < 0) {
1796  op.emitOpError("reduce axis must not be negative");
1797  return failure();
1798  }
1799  if (inputType.hasRank()) {
1800  int64_t inputRank = inputType.getRank();
1801  // We allow for a special case where the input/output shape has rank 0 and
1802  // axis is also 0.
1803  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1804  op.emitOpError("expect input tensor rank (")
1805  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1806  << ")";
1807  return failure();
1808  }
1809  }
1810  if (outputType.hasRank()) {
1811  int64_t outputRank = outputType.getRank();
1812  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1813  op.emitOpError(
1814  "expect output tensor rank to be equal to input tensor rank");
1815  return failure();
1816  }
1817  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1818  op.emitOpError("expect output tensor rank (")
1819  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1820  << ")";
1821  return failure();
1822  }
1823  // We can only verify the reduced dimension size to be 1 if this is not
1824  // the special case of output rank == 0.
1825  if (outputRank != 0) {
1826  auto outputShape = outputType.getShape();
1827  if (!outputType.isDynamicDim(reduceAxis) &&
1828  outputShape[reduceAxis] != 1) {
1829  op.emitOpError("expect reduced dimension size to be 1, got ")
1830  << outputShape[reduceAxis];
1831  return failure();
1832  }
1833  }
1834  }
1835  return success();
1836 }
1837 
1838 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1839 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1840 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1841 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1842 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1843 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1844 
1845 static LogicalResult NAryInferReturnTypes(
1846  const ValueShapeRange &operands,
1847  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1848  llvm::SmallVector<int64_t> outShape;
1849  if (resolveBroadcastShape(operands, outShape).failed()) {
1850  inferredReturnShapes.push_back(ShapedTypeComponents());
1851  } else {
1852  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1853  }
1854  return success();
1855 }
1856 
1857 #define NARY_SHAPE_INFER(OP) \
1858  LogicalResult OP::inferReturnTypeComponents( \
1859  MLIRContext *context, ::std::optional<Location> location, \
1860  ValueShapeRange operands, DictionaryAttr attributes, \
1861  OpaqueProperties properties, RegionRange regions, \
1862  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1863  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1864  }
1865 
1866 NARY_SHAPE_INFER(tosa::AbsOp)
1867 NARY_SHAPE_INFER(tosa::AddOp)
1868 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1869 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1870 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1871 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1872 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1873 NARY_SHAPE_INFER(tosa::CastOp)
1874 NARY_SHAPE_INFER(tosa::CeilOp)
1875 NARY_SHAPE_INFER(tosa::ClampOp)
1876 NARY_SHAPE_INFER(tosa::ClzOp)
1877 NARY_SHAPE_INFER(tosa::CosOp)
1878 NARY_SHAPE_INFER(tosa::ExpOp)
1879 NARY_SHAPE_INFER(tosa::FloorOp)
1880 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1881 NARY_SHAPE_INFER(tosa::GreaterOp)
1882 NARY_SHAPE_INFER(tosa::IdentityOp)
1883 NARY_SHAPE_INFER(tosa::IntDivOp)
1884 NARY_SHAPE_INFER(tosa::LogOp)
1885 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1886 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1887 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1888 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1889 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1890 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1891 NARY_SHAPE_INFER(tosa::MaximumOp)
1892 NARY_SHAPE_INFER(tosa::MinimumOp)
1893 NARY_SHAPE_INFER(tosa::NegateOp)
1894 NARY_SHAPE_INFER(tosa::PowOp)
1895 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1896 NARY_SHAPE_INFER(tosa::RescaleOp)
1897 NARY_SHAPE_INFER(tosa::ReverseOp)
1898 NARY_SHAPE_INFER(tosa::RsqrtOp)
1899 NARY_SHAPE_INFER(tosa::SinOp)
1900 NARY_SHAPE_INFER(tosa::SelectOp)
1901 NARY_SHAPE_INFER(tosa::SubOp)
1902 NARY_SHAPE_INFER(tosa::TanhOp)
1903 NARY_SHAPE_INFER(tosa::ErfOp)
1904 NARY_SHAPE_INFER(tosa::SigmoidOp)
1905 #undef NARY_SHAPE_INFER
1906 
1907 static LogicalResult poolingInferReturnTypes(
1908  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1909  ArrayRef<int64_t> pad,
1910  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1911  llvm::SmallVector<int64_t> outputShape;
1912  outputShape.resize(4, ShapedType::kDynamic);
1913 
1914  // If the input type is unranked, we only know the output rank.
1915  if (!inputShape) {
1916  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1917  return success();
1918  }
1919 
1920  // The batch and channel dimensions are passed through unchanged by pooling.
1921  outputShape[0] = inputShape.getDimSize(0);
1922  outputShape[3] = inputShape.getDimSize(3);
1923 
1924  int64_t height = inputShape.getDimSize(1);
1925  int64_t width = inputShape.getDimSize(2);
1926 
1927  if (!ShapedType::isDynamic(height)) {
1928  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1929  outputShape[1] = padded / stride[0] + 1;
1930  }
1931 
1932  if (!ShapedType::isDynamic(width)) {
1933  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1934  outputShape[2] = padded / stride[1] + 1;
1935  }
1936 
1937  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1938  return success();
1939 }
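 // Illustrative note (editorial): for an input height of 8 with pad = {0, 0},
 // kernel = 2 and stride = 2, the computation above yields
 // (8 + 0 + 0 - 2) / 2 + 1 = 4 output rows.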
1940 
1941 LogicalResult Conv2DOp::inferReturnTypeComponents(
1942  MLIRContext *context, ::std::optional<Location> location,
1943  Conv2DOp::Adaptor adaptor,
1944  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1945  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1946 
1947  int64_t inputWidth = ShapedType::kDynamic;
1948  int64_t inputHeight = ShapedType::kDynamic;
1949  int64_t weightWidth = ShapedType::kDynamic;
1950  int64_t weightHeight = ShapedType::kDynamic;
1951 
1952  // Input shape describes input width/height and batch.
1953 
1954  ShapeAdaptor inputShape(adaptor.getInput().getType());
1955  if (inputShape.hasRank()) {
1956  outputShape[0] = inputShape.getDimSize(0);
1957  inputHeight = inputShape.getDimSize(1);
1958  inputWidth = inputShape.getDimSize(2);
1959  }
1960 
1961  // The weight shape describes the filter height/width and the output channels.
1962  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1963  if (weightShape.hasRank()) {
1964  outputShape[3] = weightShape.getDimSize(0);
1965  weightHeight = weightShape.getDimSize(1);
1966  weightWidth = weightShape.getDimSize(2);
1967  }
1968 
1969  // Bias shape can describe the output channels.
1970  ShapeAdaptor biasShape(adaptor.getBias().getType());
1971  if (biasShape.hasRank()) {
1972  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1973  ? biasShape.getDimSize(0)
1974  : outputShape[3];
1975  }
1976 
1977  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1978  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1979  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1980 
1981  if (!ShapedType::isDynamic(inputHeight) &&
1982  !ShapedType::isDynamic(weightHeight)) {
1983  int64_t inputSize = inputHeight + padding[0] + padding[1];
1984  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1985  int64_t unstridedResult = inputSize - filterSize + 1;
1986  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1987  }
1988 
1989  if (!ShapedType::isDynamic(inputWidth) &&
1990  !ShapedType::isDynamic(weightWidth)) {
1991  int64_t inputSize = inputWidth + padding[2] + padding[3];
1992  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1993  int64_t unstridedResult = inputSize - filterSize + 1;
1994  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1995  }
1996 
1997  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1998  return success();
1999 }
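 // Illustrative note (editorial): for input height 32, pad = {1, 1}, a 3x3
 // kernel with dilation 1 and stride 1, the height computation above gives
 // inputSize = 34, filterSize = 3, unstridedResult = 32, and
 // outputShape[1] = (32 - 1) / 1 + 1 = 32, i.e. a "same"-sized result.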
2000 
2001 LogicalResult Conv2DOp::verify() {
2002  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2003  return failure();
2004  return success();
2005 }
2006 
2007 LogicalResult Conv3DOp::inferReturnTypeComponents(
2008  MLIRContext *context, ::std::optional<Location> location,
2009  Conv3DOp::Adaptor adaptor,
2010  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2011  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
2012 
2013  int64_t inputWidth = ShapedType::kDynamic;
2014  int64_t inputHeight = ShapedType::kDynamic;
2015  int64_t inputDepth = ShapedType::kDynamic;
2016 
2017  int64_t weightWidth = ShapedType::kDynamic;
2018  int64_t weightHeight = ShapedType::kDynamic;
2019  int64_t weightDepth = ShapedType::kDynamic;
2020 
2021  // Input shape describes input width/height and batch.
2022  ShapeAdaptor inputShape(adaptor.getInput().getType());
2023  if (inputShape.hasRank()) {
2024  outputShape[0] = inputShape.getDimSize(0);
2025  inputDepth = inputShape.getDimSize(1);
2026  inputHeight = inputShape.getDimSize(2);
2027  inputWidth = inputShape.getDimSize(3);
2028  }
2029 
2030  // The weight shape describes the filter depth/height/width and the output channels.
2031  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2032  if (weightShape.hasRank()) {
2033  outputShape[4] = weightShape.getDimSize(0);
2034  weightDepth = weightShape.getDimSize(1);
2035  weightHeight = weightShape.getDimSize(2);
2036  weightWidth = weightShape.getDimSize(3);
2037  }
2038 
2039  // Bias shape can describe the output channels.
2040  ShapeAdaptor biasShape(adaptor.getBias().getType());
2041  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
2042  outputShape[4] = biasShape.getDimSize(0);
2043  }
2044 
2045  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2046  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2047  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
2048 
2049  if (!ShapedType::isDynamic(inputDepth) &&
2050  !ShapedType::isDynamic(weightDepth)) {
2051  int64_t inputSize = inputDepth + pad[0] + pad[1];
2052  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
2053  int64_t unstridedResult = inputSize - filterSize + 1;
2054  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2055  }
2056 
2057  if (!ShapedType::isDynamic(inputHeight) &&
2058  !ShapedType::isDynamic(weightHeight)) {
2059  int64_t inputSize = inputHeight + pad[2] + pad[3];
2060  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
2061  int64_t unstridedResult = inputSize - filterSize + 1;
2062  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2063  }
2064 
2065  if (!ShapedType::isDynamic(inputWidth) &&
2066  !ShapedType::isDynamic(weightWidth)) {
2067  int64_t inputSize = inputWidth + pad[4] + pad[5];
2068  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
2069  int64_t unstridedResult = inputSize - filterSize + 1;
2070  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
2071  }
2072 
2073  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2074  return success();
2075 }
2076 
2077 LogicalResult Conv3DOp::verify() {
2078  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2079  return failure();
2080  return success();
2081 }
2082 
2083 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
2084  MLIRContext *context, ::std::optional<Location> location,
2085  AvgPool2dOp::Adaptor adaptor,
2086  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2087  ShapeAdaptor inputShape(adaptor.getInput().getType());
2088  const Properties &prop = adaptor.getProperties();
2089  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2090  inferredReturnShapes);
2091 }
2092 
2093 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2094  MLIRContext *context, ::std::optional<Location> location,
2095  MaxPool2dOp::Adaptor adaptor,
2096  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2097  ShapeAdaptor inputShape(adaptor.getInput().getType());
2098  const Properties &prop = adaptor.getProperties();
2099  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2100  inferredReturnShapes);
2101 }
2102 
2103 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2104  MLIRContext *context, ::std::optional<Location> location,
2105  DepthwiseConv2DOp::Adaptor adaptor,
2106  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2107  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2108 
2109  int64_t inputWidth = ShapedType::kDynamic;
2110  int64_t inputHeight = ShapedType::kDynamic;
2111  int64_t inputChannels = ShapedType::kDynamic;
2112 
2113  int64_t weightWidth = ShapedType::kDynamic;
2114  int64_t weightHeight = ShapedType::kDynamic;
2115  int64_t depthChannels = ShapedType::kDynamic;
2116 
2117  // Input shape describes input width/height and batch.
2118  ShapeAdaptor inputShape(adaptor.getInput().getType());
2119  if (inputShape.hasRank()) {
2120  outputShape[0] = inputShape.getDimSize(0);
2121  inputHeight = inputShape.getDimSize(1);
2122  inputWidth = inputShape.getDimSize(2);
2123  inputChannels = inputShape.getDimSize(3);
2124  }
2125 
2126  // The weight shape describes the filter height/width, the input channels, and the channel multiplier.
2127  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2128  if (weightShape.hasRank()) {
2129  weightHeight = weightShape.getDimSize(0);
2130  weightWidth = weightShape.getDimSize(1);
2131  inputChannels = ShapedType::isDynamic(inputChannels)
2132  ? weightShape.getDimSize(2)
2133  : inputChannels;
2134  depthChannels = weightShape.getDimSize(3);
2135  }
2136 
2137  // If both inputChannels and depthChannels are available we can determine
2138  // the output channels.
2139  if (!ShapedType::isDynamic(inputChannels) &&
2140  !ShapedType::isDynamic(depthChannels)) {
2141  outputShape[3] = inputChannels * depthChannels;
2142  }
2143 
2144  // Bias shape can describe the output channels.
2145  ShapeAdaptor biasShape(adaptor.getBias().getType());
2146  if (biasShape.hasRank()) {
2147  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2148  ? biasShape.getDimSize(0)
2149  : outputShape[3];
2150  }
2151 
2152  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2153  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2154  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2155 
2156  if (!ShapedType::isDynamic(inputHeight) &&
2157  !ShapedType::isDynamic(weightHeight)) {
2158  int64_t inputSize = inputHeight + padding[0] + padding[1];
2159  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2160  int64_t unstridedResult = inputSize - filterSize + 1;
2161  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2162  }
2163 
2164  if (!ShapedType::isDynamic(inputWidth) &&
2165  !ShapedType::isDynamic(weightWidth)) {
2166  int64_t inputSize = inputWidth + padding[2] + padding[3];
2167  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2168  int64_t unstridedResult = inputSize - filterSize + 1;
2169  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2170  }
2171 
2172  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2173  return success();
2174 }
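 // Illustrative note (editorial): the output channel count is the product of
 // the input channels and the channel multiplier taken from the weight's last
 // dimension, e.g. 8 input channels with a multiplier of 2 give
 // outputShape[3] = 16; the spatial dimensions follow the same formula as
 // regular convolution.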
2175 
2176 LogicalResult DepthwiseConv2DOp::verify() {
2177  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2178  return failure();
2179  return success();
2180 }
2181 
2182 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
2183  MLIRContext *context, ::std::optional<Location> location,
2184  TransposeConv2DOp::Adaptor adaptor,
2185  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2186  // outputShape is mutable.
2187  llvm::SmallVector<int64_t> outputShape =
2188  convertToMlirShape(adaptor.getOutShape());
2189 
2190  int64_t inputWidth = ShapedType::kDynamic;
2191  int64_t inputHeight = ShapedType::kDynamic;
2192  int64_t weightWidth = ShapedType::kDynamic;
2193  int64_t weightHeight = ShapedType::kDynamic;
2194 
2195  // Input shape describes input width/height and batch.
2196  ShapeAdaptor inputShape(adaptor.getInput().getType());
2197  if (inputShape.hasRank()) {
2198  outputShape[0] = ShapedType::isDynamic(outputShape[0])
2199  ? inputShape.getDimSize(0)
2200  : outputShape[0];
2201  inputHeight = inputShape.getDimSize(1);
2202  inputWidth = inputShape.getDimSize(2);
2203  }
2204 
2205  // The weight shape describes the filter height/width and the output channels.
2206  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2207  if (weightShape.hasRank()) {
2208  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2209  ? weightShape.getDimSize(0)
2210  : outputShape[3];
2211  weightHeight = weightShape.getDimSize(1);
2212  weightWidth = weightShape.getDimSize(2);
2213  }
2214 
2215  // Bias shape can describe the output channels.
2216  ShapeAdaptor biasShape(adaptor.getBias().getType());
2217  if (biasShape.hasRank()) {
2218  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2219  ? biasShape.getDimSize(0)
2220  : outputShape[3];
2221  }
2222 
2223  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
2224  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2225 
2226  if (!ShapedType::isDynamic(inputHeight) &&
2227  !ShapedType::isDynamic(weightHeight)) {
2228  int64_t calculateSize =
2229  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2230  outputShape[1] =
2231  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
2232  }
2233 
2234  if (!ShapedType::isDynamic(inputWidth) &&
2235  !ShapedType::isDynamic(weightWidth)) {
2236  int64_t calculateSize =
2237  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2238  outputShape[2] =
2239  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
2240  }
2241 
2242  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2243  return success();
2244 }
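 // Illustrative note (editorial): transpose convolution grows the spatial
 // dimensions, e.g. input height 4, stride 2, out_pad = {0, 0} and a kernel
 // height of 3 give (4 - 1) * 2 + 0 + 0 + 3 = 9 output rows.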
2245 
2246 LogicalResult TransposeConv2DOp::verify() {
2247  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2248  return failure();
2249  return success();
2250 }
2251 
2252 LogicalResult IfOp::inferReturnTypeComponents(
2253  MLIRContext *context, ::std::optional<Location> location,
2254  IfOp::Adaptor adaptor,
2255  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2256  llvm::SmallVector<tosa::YieldOp> yieldOps;
2257  for (Region *region : adaptor.getRegions()) {
2258  for (auto &block : *region)
2259  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2260  yieldOps.push_back(returnOp);
2261  }
2262 
2263  if (yieldOps.empty())
2264  return failure();
2265 
2266  // Get the initial type information for the yield op.
2267  llvm::SmallVector<ValueKnowledge> resultKnowledge;
2268  resultKnowledge.reserve(yieldOps.front().getNumOperands());
2269  for (auto operand : yieldOps.front().getOperands()) {
2270  resultKnowledge.push_back(
2271  ValueKnowledge::getKnowledgeFromType(operand.getType()));
2272  }
2273 
2274  for (auto yieldOp : yieldOps) {
2275  if (resultKnowledge.size() != yieldOp.getNumOperands())
2276  return failure();
2277 
2278  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2279  int32_t index = it.index();
2280  auto meet = ValueKnowledge::meet(
2281  resultKnowledge[index],
2282  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
2283  if (!meet)
2284  continue;
2285  resultKnowledge[index] = meet;
2286  }
2287  }
2288 
2289  for (const ValueKnowledge &result : resultKnowledge) {
2290  inferredReturnShapes.push_back(result.getShapedTypeComponents());
2291  }
2292 
2293  return success();
2294 }
2295 
2296 LogicalResult WhileOp::inferReturnTypeComponents(
2297  MLIRContext *context, ::std::optional<Location> location,
2298  WhileOp::Adaptor adaptor,
2299  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2300  llvm::SmallVector<tosa::YieldOp> yieldOps;
2301  for (auto &block : adaptor.getBody())
2302  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2303  yieldOps.push_back(returnOp);
2304 
2305  // TOSA's while op must have a tosa.yield as its terminator. If none is
2306  // found, this tosa.while is invalid.
2307  if (yieldOps.empty())
2308  return failure();
2309 
2310  // Get the initial type information from the operand types.
2311  llvm::SmallVector<ValueKnowledge> resultKnowledge;
2312  resultKnowledge.reserve(yieldOps.front().getNumOperands());
2313  for (auto operand : yieldOps.front().getOperands()) {
2314  resultKnowledge.push_back(
2315  ValueKnowledge::getKnowledgeFromType(operand.getType()));
2316  }
2317 
2318  for (auto yieldOp : yieldOps) {
2319  if (resultKnowledge.size() != yieldOp.getNumOperands())
2320  return failure();
2321 
2322  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2323  int32_t index = it.index();
2324  if (auto meet = ValueKnowledge::meet(
2325  resultKnowledge[index],
2326  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
2327  resultKnowledge[index] = meet;
2328  }
2329  }
2330  }
2331 
2332  for (const ValueKnowledge &result : resultKnowledge) {
2333  inferredReturnShapes.push_back(result.getShapedTypeComponents());
2334  }
2335 
2336  return success();
2337 }
2338 
2339 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
2340  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
2341  return llvm::to_vector<4>(vt.getShape());
2342  return std::nullopt;
2343 }
2344 
2345 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
2346 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2347  // Create the regions for 'then'.
2348  result.regions.reserve(2);
2349  Region *thenRegion = result.addRegion();
2350  Region *elseRegion = result.addRegion();
2351 
2352  auto &builder = parser.getBuilder();
2353  OpAsmParser::UnresolvedOperand cond;
2354  // Create an i1 tensor type for the boolean condition.
2355  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
2356  if (parser.parseOperand(cond) ||
2357  parser.resolveOperand(cond, i1Type, result.operands))
2358  return failure();
2359  // Parse optional results type list.
2360  if (parser.parseOptionalArrowTypeList(result.types))
2361  return failure();
2362  // Parse the 'then' region.
2363  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2364  return failure();
2365 
2366  // If we find an 'else' keyword then parse the 'else' region.
2367  if (!parser.parseOptionalKeyword("else")) {
2368  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2369  return failure();
2370  }
2371 
2372  // Parse the optional attribute list.
2373  if (parser.parseOptionalAttrDict(result.attributes))
2374  return failure();
2375  return success();
2376 }
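 // Illustrative note (editorial): the parser above accepts IR roughly of the
 // form below (operand and type spellings are only an example):
 //   %r = tosa.cond_if %cond -> (tensor<f32>) {
 //     tosa.yield %a : tensor<f32>
 //   } else {
 //     tosa.yield %b : tensor<f32>
 //   }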
2377 
2378 void IfOp::print(OpAsmPrinter &p) {
2379  bool printBlockTerminators = false;
2380 
2381  p << " " << getCond();
2382  if (!getResults().empty()) {
2383  p << " -> (" << getResultTypes() << ")";
2384  // Print yield explicitly if the op defines values.
2385  printBlockTerminators = true;
2386  }
2387  p << ' ';
2388  p.printRegion(getThenBranch(),
2389  /*printEntryBlockArgs=*/false,
2390  /*printBlockTerminators=*/printBlockTerminators);
2391 
2392  // Print the 'else' region if it exists and has a block.
2393  auto &elseRegion = getElseBranch();
2394  if (!elseRegion.empty()) {
2395  p << " else ";
2396  p.printRegion(elseRegion,
2397  /*printEntryBlockArgs=*/false,
2398  /*printBlockTerminators=*/printBlockTerminators);
2399  }
2400 
2401  p.printOptionalAttrDict((*this)->getAttrs());
2402 }
2403 
2404 LogicalResult ReverseOp::verify() {
2405  TensorType inputType = getInput1().getType();
2406  TensorType outputType = getOutput().getType();
2407  int32_t reverseAxis = getAxis();
2408 
2409  if (reverseAxis < 0)
2410  return emitOpError("expected non-negative reverse axis");
2411  if (inputType.hasRank()) {
2412  int64_t inputRank = inputType.getRank();
2413  // We allow for a special case where the input/output shape has rank 0 and
2414  // axis is also 0.
2415  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
2416  return emitOpError("expect input tensor rank (")
2417  << inputRank << ") to be larger than reverse axis (" << reverseAxis
2418  << ")";
2419  }
2420  if (outputType.hasRank()) {
2421  int64_t outputRank = outputType.getRank();
2422  if (inputType.hasRank() && outputRank != inputType.getRank())
2423  return emitOpError(
2424  "expect output tensor rank to be equal to input tensor rank");
2425  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
2426  return emitOpError("expect output tensor rank (")
2427  << outputRank << ") to be larger than reverse axis ("
2428  << reverseAxis << ")";
2429  }
2430  return success();
2431 }
2432 
2433 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
2434 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2435  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2436  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2437  Region *cond = result.addRegion();
2438  Region *body = result.addRegion();
2439 
2440  OptionalParseResult listResult =
2441  parser.parseOptionalAssignmentList(regionArgs, operands);
2442  if (listResult.has_value() && failed(listResult.value()))
2443  return failure();
2444 
2445  FunctionType functionType;
2446  SMLoc typeLoc = parser.getCurrentLocation();
2447  if (failed(parser.parseColonType(functionType)))
2448  return failure();
2449 
2450  result.addTypes(functionType.getResults());
2451 
2452  if (functionType.getNumInputs() != operands.size()) {
2453  return parser.emitError(typeLoc)
2454  << "expected as many input types as operands "
2455  << "(expected " << operands.size() << " got "
2456  << functionType.getNumInputs() << ")";
2457  }
2458 
2459  // Resolve input operands.
2460  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2461  parser.getCurrentLocation(),
2462  result.operands)))
2463  return failure();
2464 
2465  // Propagate the types into the region arguments.
2466  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2467  regionArgs[i].type = functionType.getInput(i);
2468 
2469  return failure(parser.parseRegion(*cond, regionArgs) ||
2470  parser.parseKeyword("do") || parser.parseRegion(*body) ||
2471  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2472 }
2473 
2474 static void printInitializationList(OpAsmPrinter &parser,
2475  Block::BlockArgListType blocksArgs,
2476  ValueRange initializers,
2477  StringRef prefix = "") {
2478  assert(blocksArgs.size() == initializers.size() &&
2479  "expected same length of arguments and initializers");
2480  if (initializers.empty())
2481  return;
2482 
2483  parser << prefix << '(';
2484  llvm::interleaveComma(
2485  llvm::zip(blocksArgs, initializers), parser,
2486  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
2487  parser << ")";
2488 }
2489 
2490 void WhileOp::print(OpAsmPrinter &parser) {
2491  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2492  " ");
2493  parser << " : ";
2494  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2495  parser << ' ';
2496  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2497  parser << " do ";
2498  parser.printRegion(getBody());
2499  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2500 }
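 // Illustrative note (editorial): together, the parser and printer above give
 // tosa.while_loop a textual form roughly like the following (names are only
 // an example):
 //   %r = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
 //     ... compute %continue ...
 //     tosa.yield %continue : tensor<i1>
 //   } do {
 //   ^bb0(%arg1: tensor<i32>):
 //     ... compute %next ...
 //     tosa.yield %next : tensor<i32>
 //   }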
2501 
2502 LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
2503  Type zpElemType = zpAttr.getElementType();
2504  if (auto quantType =
2505  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
2506  zp = quantType.getZeroPoint();
2507  return success();
2508  }
2509  if (llvm::isa<FloatType>(zpElemType)) {
2510  // A non-zero zero point is not allowed for float types.
2511  if (!zpAttr.getValues<APFloat>()[0].isZero())
2512  return failure();
2513  zp = 0;
2514  return success();
2515  }
2516  if (llvm::isa<IntegerType>(zpElemType)) {
2517  zp = zpAttr.getValues<APInt>()[0].getSExtValue();
2518  return success();
2519  }
2520  // A zero point is not allowed for any other (unsupported) element type.
2521  return failure();
2522 }
2523 
2524 // Create a rank-1 const tensor for zero point of the source tensor.
2525 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
2526  Location loc,
2527  Type srcElemType,
2528  int64_t zp) {
2529  srcElemType = getElementTypeOrSelf(srcElemType);
2530  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
2531  srcElemType = quantType.getStorageType();
2532  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
2533  if (llvm::isa<FloatType>(srcElemType)) {
2534  auto zpAttr = DenseElementsAttr::get(
2535  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
2536  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
2537  }
2538  if (llvm::isa<IntegerType>(srcElemType)) {
2539  auto zpAttr =
2540  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
2541  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
2542  }
2543  llvm::errs() << "zero point is not allowed for unsupported data types\n";
2544  return std::nullopt;
2545 }
2546 
2547 //===----------------------------------------------------------------------===//
2548 // TOSA Shape and Shape Operators Helper functions.
2549 //===----------------------------------------------------------------------===//
2550 
2551 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
2552  return mlir::isa<tosa::shapeType>(t);
2553 }
2554 
2555 LogicalResult
2556 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
2557  int rank) {
2558  if (rank < 0)
2559  return emitError() << "invalid rank (must be >= 0): " << rank;
2560  return success();
2561 }
2562 
2563 LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
2564  for (auto v : op->getOperands()) {
2565  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
2566  Operation *definingOp = v.getDefiningOp();
2567  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
2568  return op->emitOpError("shape operand is not compile time resolvable");
2569  }
2570  }
2571  }
2572  return success();
2573 }
2574 
2575 LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
2576  for (auto type : op->getOperandTypes()) {
2577  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2578  return op->emitOpError("must have operands with tosa shape type");
2579  }
2580  }
2581  for (auto type : op->getResultTypes()) {
2582  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2583  return op->emitOpError("must have result with tosa shape type");
2584  }
2585  }
2586  return success();
2587 }
2588 
2589 LogicalResult
2590 OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
2591  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
2592  failed(verifyTosaShapeOperator(op)))
2593  return failure();
2594 
2595  // Helper that returns the rank of a tosa shape type.
2596  auto getRank = [](const Type type) {
2597  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
2598  };
2599  auto operandTypes = op->getOperandTypes();
2600  auto resultTypes = op->getResultTypes();
2601 
2602  auto rank = getRank(*op->getOperandTypes().begin());
2603  for (auto type : operandTypes) {
2604  if (getRank(type) != rank) {
2605  return op->emitOpError("operands don't have matching ranks");
2606  }
2607  }
2608  for (auto type : resultTypes) {
2609  if (getRank(type) != rank) {
2610  return op->emitOpError("result shape has different rank than operands");
2611  }
2612  }
2613  return success();
2614 }
2615 
2616 //===----------------------------------------------------------------------===//
2617 // TOSA Shape Operators verify functions.
2618 //===----------------------------------------------------------------------===//
2619 
2620 LogicalResult tosa::ConstShapeOp::verify() {
2621  // Check that the value attribute is one-dimensional (rank 1).
2622  auto valuesRank = getValue().getType().getRank();
2623  if (valuesRank != 1)
2624  return emitOpError("expect elements in attribute value with rank 1");
2625  // Check that the number of elements in the value attribute equals the rank of the result shape type.
2626  auto count = getValue().getNumElements();
2627  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
2628  if (!(count == rank || (count == 1 && rank == 0))) {
2629  return emitOpError("expect number of elements in attribute value (")
2630  << count << ") to be equal to the rank (" << rank
2631  << ") for the result shape type";
2632  }
2633  return success();
2634 }
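 // Illustrative note (editorial): a well-formed op under this verifier looks
 // roughly like the following (attribute and type spelling is only an example):
 //   %s = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>}
 //          : !tosa.shape<3>
 // where the rank-1 value attribute has exactly as many elements as the rank
 // of the result shape type.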
2635 
2636 //===----------------------------------------------------------------------===//
2637 // TOSA Attribute Definitions.
2638 //===----------------------------------------------------------------------===//
2639 
2640 #define GET_ATTRDEF_CLASSES
2641 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2642 
2643 //===----------------------------------------------------------------------===//
2644 // TOSA Type Definitions.
2645 //===----------------------------------------------------------------------===//
2646 #define GET_TYPEDEF_CLASSES
2647 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
2648 
2649 //===----------------------------------------------------------------------===//
2650 // TOSA Operator Definitions.
2651 //===----------------------------------------------------------------------===//
2652 
2653 #define GET_OP_CLASSES
2654 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:590
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType)
Definition: TosaOps.cpp:389
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1739
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:1764
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:219
static void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter unary operators that have scale relationship between their...
Definition: TosaOps.cpp:654
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1907
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:550
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:674
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias)
The tosa.fully_connected op has its own builder as it does not have strides/dilation/padding.
Definition: TosaOps.cpp:571
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:1789
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:1857
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings, Value padConst)
This builder is called on TOSA pad operator when an explicit pad_const value is passed in.
Definition: TosaOps.cpp:689
static LogicalResult verifyConvOpModes(T op)
Definition: TosaOps.cpp:342
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1845
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:1755
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:708
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:527
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:630
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:1051
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:2474
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class indicates that op operates on tosa shape types.
Definition: TosaOps.h:119
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:47
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:114
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
Definition: Operation.cpp:915
LogicalResult verifyTosaShapeOperator(Operation *op)
Definition: TosaOps.cpp:2575
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition: TosaOps.cpp:2590
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition: TosaOps.cpp:2563
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broadcast compatible.
Definition: Traits.cpp:60
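A minimal sketch of querying the broadcast helper above (assumes mlir/Dialect/Traits.h is included; the shapes and the wrapper function are illustrative):

static bool broadcastExample() {
  llvm::SmallVector<int64_t> resultShape;
  const int64_t lhs[] = {1, 4, 5};
  const int64_t rhs[] = {3, 4, 1};
  // Returns true and fills resultShape with {3, 4, 5}; returns false and
  // clears it when the shapes are not broadcast compatible.
  return mlir::OpTrait::util::getBroadcastedShape(lhs, rhs, resultShape);
}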
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder.
Definition: QuantUtils.cpp:191
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
Construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:282
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp)
Definition: TosaOps.cpp:2502
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:262
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
Definition: QuantUtils.cpp:155
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr)
Definition: TosaOps.cpp:175
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint, bZp: input b zeropoint.
Definition: QuantUtils.cpp:207
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition: TosaOps.cpp:2525
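A minimal sketch of materializing an i8 zero-point tensor with the helper above, e.g. from inside a rewrite; makeZeroPoint is an illustrative wrapper, not a symbol in this file:

static mlir::Value makeZeroPoint(mlir::OpBuilder &builder, mlir::Location loc) {
  std::optional<mlir::Value> zp =
      mlir::tosa::createZeroPointTensor(builder, loc, builder.getI8Type(), /*zp=*/0);
  // Null Value if the element type is not supported by the helper.
  return zp ? *zp : mlir::Value();
}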
bool isa_tosa_shape_type(mlir::Type t)
Definition: TosaOps.cpp:2551
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr)
Definition: TosaOps.cpp:197
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr, called from UnaryOpQuantInfoBuilder: inputZp: input zeropoint, outputZp: output zeropoint.
Definition: QuantUtils.cpp:236
bool getConstShapeValue(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:497
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
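A minimal sketch of binding a constant operand's attribute with matchPattern and m_Constant (both referenced above), the same idiom the shape and zero-point helpers in this file rely on; getConstantValues is a hypothetical name:

static bool getConstantValues(mlir::Value value, mlir::DenseElementsAttr &attr) {
  // Succeeds only when `value` is defined by a constant-foldable op.
  return mlir::matchPattern(value, mlir::m_Constant(&attr));
}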
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs; on error, this reports the error and returns failure.
Definition: Verifier.cpp:425
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
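A minimal sketch of building an op generically by filling an OperationState and handing it to OpBuilder::create (both described above); the choice of op name and the single-result signature are assumptions:

static mlir::Operation *buildGenericOp(mlir::OpBuilder &builder,
                                        mlir::Location loc,
                                        mlir::ValueRange operands,
                                        mlir::Type resultType) {
  mlir::OperationState state(loc, "tosa.identity"); // hypothetical op choice
  state.addOperands(operands);
  state.addTypes(resultType);
  return builder.create(state);
}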
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45
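A minimal sketch of intersecting the static facts carried by two types, the way TOSA shape propagation uses ValueKnowledge; combineTypes is an illustrative helper name:

static mlir::tosa::ValueKnowledge combineTypes(mlir::Type lhs, mlir::Type rhs) {
  auto lhsKnowledge = mlir::tosa::ValueKnowledge::getKnowledgeFromType(lhs);
  auto rhsKnowledge = mlir::tosa::ValueKnowledge::getKnowledgeFromType(rhs);
  // meet() keeps only the rank/dimension/element-type facts both sides agree on.
  return mlir::tosa::ValueKnowledge::meet(lhsKnowledge, rhsKnowledge);
}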