1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Quant/QuantOps.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Interfaces/InferTypeOpInterface.h"
26 #include "mlir/Transforms/InliningUtils.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace mlir;
31 using namespace mlir::tosa;
32 
33 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
34 
35 //===----------------------------------------------------------------------===//
36 // Tosa dialect interface includes.
37 //===----------------------------------------------------------------------===//
38 
39 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
40 
41 namespace {
42 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
43 
44 //===----------------------------------------------------------------------===//
45 // Dialect Function Inliner Interface.
46 //===----------------------------------------------------------------------===//
47 struct TosaInlinerInterface : public DialectInlinerInterface {
48  using DialectInlinerInterface::DialectInlinerInterface;
49 
50  //===--------------------------------------------------------------------===//
51  // Analysis Hooks.
52  //===--------------------------------------------------------------------===//
53 
54  /// All operations can be inlined by default.
55  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
56  IRMapping &map) const final {
57  return true;
58  }
59 
60  /// All regions with If and While parent operators can be inlined.
61  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
62  IRMapping &map) const final {
63  return (isa<tosa::IfOp>(dest->getParentOp()) ||
64  isa<tosa::WhileOp>(dest->getParentOp()));
65  }
66 };
67 
68 /// This class implements the bytecode interface for the Tosa dialect.
69 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
70  TosaDialectBytecodeInterface(Dialect *dialect)
71  : BytecodeDialectInterface(dialect) {}
72 
73  //===--------------------------------------------------------------------===//
74  // Attributes
75 
76  Attribute readAttribute(DialectBytecodeReader &reader) const override {
77  return ::readAttribute(getContext(), reader);
78  }
79 
80  LogicalResult writeAttribute(Attribute attr,
81  DialectBytecodeWriter &writer) const override {
82  return ::writeAttribute(attr, writer);
83  }
84 
85  //===--------------------------------------------------------------------===//
86  // Types
87 
88  Type readType(DialectBytecodeReader &reader) const override {
89  return ::readType(getContext(), reader);
90  }
91 
92  LogicalResult writeType(Type type,
93  DialectBytecodeWriter &writer) const override {
94  return ::writeType(type, writer);
95  }
96 
97  void writeVersion(DialectBytecodeWriter &writer) const final {
98  // TODO: Populate.
99  }
100 
101  std::unique_ptr<DialectVersion>
102  readVersion(DialectBytecodeReader &reader) const final {
103  // TODO: Populate
104  reader.emitError("Dialect does not support versioning");
105  return nullptr;
106  }
107 
108  LogicalResult upgradeFromVersion(Operation *topLevelOp,
109  const DialectVersion &version_) const final {
110  return success();
111  }
112 };
113 
114 } // namespace
115 
116 //===----------------------------------------------------------------------===//
117 // TOSA control flow support.
118 //===----------------------------------------------------------------------===//
119 
120 /// Returns the while loop body.
121 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
122 
123 //===----------------------------------------------------------------------===//
124 // Tosa dialect initialization.
125 //===----------------------------------------------------------------------===//
126 
127 void TosaDialect::initialize() {
128  addOperations<
129 #define GET_OP_LIST
130 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
131  >();
132  addAttributes<
133 #define GET_ATTRDEF_LIST
134 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
135  >();
136  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
137 }
138 
139 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
140  Type type, Location loc) {
141  // Tosa dialect constants only support ElementsAttr, unlike the standard
142  // dialect constant, which supports all attributes.
143  if (llvm::isa<ElementsAttr>(value))
144  return builder.create<tosa::ConstOp>(loc, type,
145  llvm::cast<ElementsAttr>(value));
146  return nullptr;
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // Parsers and printers
151 //===----------------------------------------------------------------------===//
152 
153 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
154  Attribute &attr) {
155  if (succeeded(parser.parseOptionalEqual())) {
156  if (failed(parser.parseAttribute(attr))) {
157  return parser.emitError(parser.getCurrentLocation())
158  << "expected attribute";
159  }
160  if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
161  typeAttr = TypeAttr::get(typedAttr.getType());
162  }
163  return success();
164  }
165 
166  Type type;
167  if (failed(parser.parseColonType(type))) {
168  return parser.emitError(parser.getCurrentLocation()) << "expected type";
169  }
170  typeAttr = TypeAttr::get(type);
171 
172  return success();
173 }
174 
175 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
176  Attribute attr) {
177  bool needsSpace = false;
178  auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
179  if (!typedAttr || typedAttr.getType() != type.getValue()) {
180  p << ": ";
181  p.printAttribute(type);
182  needsSpace = true; // subsequent attr value needs a space separator
183  }
184  if (attr) {
185  if (needsSpace)
186  p << ' ';
187  p << "= ";
188  p.printAttribute(attr);
189  }
190 }
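// Illustrative usage (a sketch, not taken from the upstream source): the
// parser/printer pair above accepts either an explicit typed value,
//   = dense<0> : tensor<4xi32>
// in which case the TypeAttr is derived from the value's type, or a bare
// type with no value,
//   : tensor<4xi32>
// in which case only the type attribute is populated.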
191 
192 //===----------------------------------------------------------------------===//
193 // TOSA Operator Verifiers.
194 //===----------------------------------------------------------------------===//
195 
196 static bool hasZeroDimension(ShapedType shapedType) {
197  if (!shapedType.hasRank())
198  return false;
199 
200  auto rank = shapedType.getRank();
201 
202  for (int i = 0; i < rank; i++) {
203  if (shapedType.isDynamicDim(i))
204  continue;
205  if (shapedType.getDimSize(i) == 0)
206  return true;
207  }
208 
209  return false;
210 }
211 
212 template <typename T> static LogicalResult verifyConvOp(T op) {
213  // All TOSA conv ops have an input() and weight().
214  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
215  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
216 
217  // Must be ranked tensor types
218  if (!inputType) {
219  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
220  return failure();
221  }
222  if (!weightType) {
223  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
224  return failure();
225  }
226 
227  if (hasZeroDimension(inputType))
228  return op.emitOpError() << "tensor has a dimension with size zero. Each "
229  "dimension of a tensor must have size >= 1";
230 
231  auto inputEType = inputType.getElementType();
232  auto weightEType = weightType.getElementType();
233 
234  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
235  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
236 
237  // Either both must be quantized or both unquantized.
238  if (inputIsQuant != weightIsQuant) {
239  op.emitOpError(
240  "expect both input and weight to be float or not together, got ")
241  << inputEType << " and " << weightEType;
242  return failure();
243  }
244 
245  // A quantized type must have a constructed quantization attribute, and an
246  // unquantized type must not have one.
247  if ((inputIsQuant && !op.getQuantizationInfo()) ||
248  (!inputIsQuant && op.getQuantizationInfo())) {
249  op.emitOpError("quantizationattr is required for quantized type, and not "
250  "allowed for float type");
251  return failure();
252  }
253 
254  return success();
255 }
256 
257 LogicalResult tosa::ArgMaxOp::verify() {
258  // Ensure the output is an integer type.
259  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
260  if (!resultETy.isIntOrIndex())
261  return emitOpError("result tensor is not of integer type");
262 
263  // Ensure axis is within the tensor rank
264  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
265  const int64_t axis = getAxisAttr().getInt();
266  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
267  return emitOpError("specified axis is outside the rank of the tensor");
268 
269  return success();
270 }
271 
272 LogicalResult tosa::AvgPool2dOp::verify() {
273  auto inputType = llvm::cast<ShapedType>(getInput().getType());
274  if (hasZeroDimension(inputType))
275  return emitOpError() << "tensor has a dimension with size zero. Each "
276  "dimension of a tensor must have size >= 1";
277 
278  auto inputETy = inputType.getElementType();
279  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
280 
281  if (auto quantType =
282  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
283  inputETy = quantType.getStorageType();
284 
285  if (auto quantType =
286  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
287  resultETy = quantType.getStorageType();
288 
289  auto accType = getAccType();
290  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
291  return emitOpError("accumulator type for integer tensor is not i32");
292 
293  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
294  return emitOpError("accumulator type for f16 tensor is not f16/f32");
295 
296  if (inputETy.isBF16() && !accType.isF32())
297  return emitOpError("accumulator type for bf16 tensor is not f32");
298 
299  if (inputETy.isF32() && !accType.isF32())
300  return emitOpError("accumulator type for f32 tensor is not f32");
301 
302  if ((inputETy.isF32() && resultETy.isF32()) ||
303  (inputETy.isF16() && resultETy.isF16()) ||
304  (inputETy.isBF16() && resultETy.isBF16()) ||
305  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
306  (inputETy.isInteger(16) && resultETy.isInteger(16)))
307  return success();
308 
309  return emitOpError("input/output element types are incompatible.");
310 }
311 
312 LogicalResult tosa::ClampOp::verify() {
313  mlir::Type inputETy =
314  llvm::cast<ShapedType>(getInput().getType()).getElementType();
315  if (auto quantType =
316  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
317  inputETy = quantType.getStorageType();
318  }
319  mlir::Type maxFpType = getMaxFpAttr().getType();
320  mlir::Type minFpType = getMinFpAttr().getType();
321  mlir::Type outputETy =
322  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
323  if (auto quantType =
324  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
325  outputETy = quantType.getStorageType();
326  }
327  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
328 
329  if (inputETy != outputETy)
330  return emitOpError("input/output element types are incompatible.");
331 
332  // If the input datatype is float, check that the two min/max_fp attributes
333  // share the same type and that their type is either the same as the input's
334  // datatype or a float type whose bitwidth is greater than the input's.
335  if (!inputETy.isInteger(dataTypeBitWidth)) {
336  if (((maxFpType != minFpType) ||
337  (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
338  inputETy.getIntOrFloatBitWidth())))
339  return emitOpError("min/max attributes types are incompatible with "
340  "input/output element types.");
341  }
342 
343  return success();
344 }
345 
346 //===----------------------------------------------------------------------===//
347 // TOSA Operator Quantization Builders.
348 //===----------------------------------------------------------------------===//
349 
350 /// This builder is called on all convolution operators except TransposeConv,
351 /// which has specialized output shape semantics. The builder also defines the
352 /// bitwidth of the output given the bit width of the input & weight content.
353 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
354  Type outputType, Value input, Value weight,
355  Value bias, DenseI64ArrayAttr pad,
356  DenseI64ArrayAttr stride,
357  DenseI64ArrayAttr dilation) {
358 
359  result.addOperands({input, weight, bias});
360  result.addAttribute("pad", pad);
361  result.addAttribute("stride", stride);
362  result.addAttribute("dilation", dilation);
363 
364  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
365  if (quantAttr) {
366  result.addAttribute("quantization_info", quantAttr);
367  result.addTypes(
368  buildConvOpResultTypeInfo(builder, outputType, input, weight));
369  } else {
370  result.addTypes(outputType);
371  }
372 }
373 
374 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
375 static void buildTransConvOpWithQuantInfo(
376  OpBuilder &builder, OperationState &result, Type outputType, Value input,
377  Value weight, Value bias, DenseI64ArrayAttr outpad,
378  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
379  result.addOperands({input, weight, bias});
380  result.addAttribute("out_pad", outpad);
381  result.addAttribute("stride", stride);
382  result.addAttribute("out_shape", outputShape);
383  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
384 
385  if (quantAttr) {
386  result.addAttribute("quantization_info", quantAttr);
387  result.addTypes(
388  buildConvOpResultTypeInfo(builder, outputType, input, weight));
389  } else {
390  result.addTypes(outputType);
391  }
392 }
393 
394 /// The tosa.fully_connected op has its own builder as it does not have
395 /// strides/dilation/padding.
396 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
397  Type outputType, Value input, Value weight,
398  Value bias) {
399 
400  result.addOperands({input, weight, bias});
401  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
402  if (quantAttr) {
403  result.addAttribute("quantization_info", quantAttr);
404  result.addTypes(
405  buildConvOpResultTypeInfo(builder, outputType, input, weight));
406  } else {
407  result.addTypes(outputType);
408  }
409 }
410 
411 /// The tosa.matmul op is also intended to be generated where a fully_connected
412 /// op would be needed but the weight is not a constant. In this case, the
413 /// fully_connected op must be expressed using matmul.
414 /// TODO: Add link to the legalization document explaining this.
415 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
416  OperationState &result, Type outputType,
417  Value a, Value b) {
418  result.addOperands({a, b});
419  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
420 
421  if (quantAttr) {
422  result.addAttribute("quantization_info", quantAttr);
423 
424  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
425  assert(inputType && "Input must be a shaped tensor type!");
426 
427  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
428  inputType.getElementType());
429  assert(inputQType && "Tensor must have quantized datatype!");
430 
431  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
432 
433  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
434  assert(outputShapedType && "Output must be a shaped type");
435 
436  IntegerType accElementType;
437  if (inputBits == 16)
438  accElementType = builder.getIntegerType(48);
439  else
440  accElementType = builder.getI32Type();
441  auto accType = outputShapedType.clone(accElementType);
442  result.addTypes(accType);
443  } else {
444  result.addTypes(outputType);
445  }
446 }
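// Worked example (illustrative only, types chosen here for the sake of the
// example): for two quantized i8 operands the builder above attaches the
// quantization_info and widens the result element type to i32, so a requested
// result of tensor<1x4x16xi8> becomes tensor<1x4x16xi32>; for i16 storage the
// accumulator element type becomes i48 instead.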
447 
448 /// Both the tosa.avg_pool2d and unary ops use the same
449 /// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as it
450 /// has additional parameters not part of the unary ops.
451 static void
452 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
453  Type outputType, Value input,
454  DenseArrayAttr kernel, DenseArrayAttr stride,
455  DenseArrayAttr pad, TypeAttr acc_type) {
456  result.addOperands(input);
457  result.addAttribute("kernel", kernel);
458  result.addAttribute("stride", stride);
459  result.addAttribute("pad", pad);
460  result.addAttribute("acc_type", acc_type);
461  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
462  if (quantAttr)
463  result.addAttribute("quantization_info", quantAttr);
464  result.types.push_back(outputType);
465 }
466 
467 /// This builder is called on single-parameter unary operators that have a scale
468 /// relationship between their input and output, expressed by the
469 /// UnaryOpQuantizationAttr.
470 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
471  OperationState &result, Type outputType,
472  Value input) {
473  result.addOperands(input);
474  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
475  if (quantAttr)
476  result.addAttribute("quantization_info", quantAttr);
477  result.types.push_back(outputType);
478 }
479 
480 /// This builder is called on the TOSA pad operator that needs to create its own
481 /// OptionalAttr quantization_attr parameter to scale the padding values
482 /// correctly. The absence of pad_const is interpreted as zero-padding.
483 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
484  Type outputType, Value input,
485  Value paddings) {
486  result.addOperands({input, paddings});
487  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
488  if (quantAttr)
489  result.addAttribute("quantization_info", quantAttr);
490  result.types.push_back(outputType);
491 }
492 
493 /// This builder is called on the TOSA pad operator when an explicit pad_const
494 /// value is passed in. It also optionally constructs quantization_attr.
495 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
496  OperationState &result,
497  Type outputType, Value input,
498  Value paddings,
499  Value padConst) {
500  result.addOperands({input, paddings, padConst});
501  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
502  if (quantAttr)
503  result.addAttribute("quantization_info", quantAttr);
504  result.types.push_back(outputType);
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // TOSA Operator Return Type Inference.
509 //===----------------------------------------------------------------------===//
510 
511 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
512  SmallVector<int64_t> &outShape) {
513  int64_t outRank = 0;
514  for (int i = 0, e = operands.size(); i != e; ++i) {
515  auto shape = operands.getShape(i);
516  if (!shape.hasRank()) {
517  // TODO(jennik): Update function to have better case handling for invalid
518  // operands and for ranked tensors.
519  return failure();
520  }
521  outRank = std::max<int64_t>(outRank, shape.getRank());
522  }
523 
524  outShape.resize(outRank, 1);
525 
526  for (int i = 0, e = operands.size(); i != e; ++i) {
527  auto shape = operands.getShape(i);
528  auto rankDiff = outShape.size() - shape.getRank();
529 
530  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
531  auto dim1 = outShape[i + rankDiff];
532  auto dim2 = shape.getDimSize(i);
533  auto resolvedDim = dim1;
534 
535  if (dim1 == 1) {
536  resolvedDim = dim2;
537  } else if (dim2 == 1) {
538  resolvedDim = dim1;
539  } else if (dim1 != dim2) {
540  return failure();
541  }
542  outShape[i + rankDiff] = resolvedDim;
543  }
544  }
545 
546  return success();
547 }
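// Worked example (illustrative only): broadcasting shapes [2, 1, 3] and
// [1, 5, 3] with the rules above resolves dimension-wise to [2, 5, 3];
// shapes [2, 3] and [4, 3] fail because 2 != 4 and neither dimension is 1.
// Lower-ranked operands are aligned to the trailing dimensions of the result.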
548 
549 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
550  MLIRContext *context, ::std::optional<Location> location,
551  ArgMaxOp::Adaptor adaptor,
552  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
553  ShapeAdaptor inputShape(adaptor.getInput().getType());
554  IntegerAttr axis = adaptor.getProperties().axis;
555  int32_t axisVal = axis.getValue().getSExtValue();
556 
557  if (!inputShape.hasRank()) {
558  inferredReturnShapes.push_back(ShapedTypeComponents());
559  return success();
560  }
561 
562  SmallVector<int64_t> outShape;
563  outShape.reserve(inputShape.getRank() - 1);
564  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
565  if (i == axisVal)
566  continue;
567  outShape.push_back(inputShape.getDimSize(i));
568  }
569 
570  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
571  return success();
572 }
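// Worked example (illustrative only): for an input of shape [2, 3, 4] with
// axis = 1, the inference above drops the reduced dimension and yields an
// output shape of [2, 4].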
573 
574 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
575  MLIRContext *context, ::std::optional<Location> location,
576  RFFT2dOp::Adaptor adaptor,
577  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
578  ShapeAdaptor inputShape(adaptor.getInput().getType());
579 
580  if (!inputShape.hasRank())
581  return failure();
582 
583  llvm::SmallVector<int64_t> outputShape;
584  outputShape.resize(3, ShapedType::kDynamic);
585  outputShape[0] = inputShape.getDimSize(0);
586  outputShape[1] = inputShape.getDimSize(1);
587  int64_t inWidth = inputShape.getDimSize(2);
588 
589  // Note that we can support this calculation symbolically
590  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
591  if (inWidth != ShapedType::kDynamic)
592  outputShape[2] = inWidth / 2 + 1;
593 
594  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
595  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
596 
597  return success();
598 }
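// Worked example (illustrative only): an [8, 16, 32] input produces two
// [8, 16, 17] outputs (real and imaginary parts), since 32 / 2 + 1 = 17.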
599 
600 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
601  MLIRContext *context, ::std::optional<Location> location,
602  FFT2dOp::Adaptor adaptor,
603  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
604  inferredReturnShapes.push_back(
605  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
606  inferredReturnShapes.push_back(
607  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
608  return success();
609 }
610 
611 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
612  MLIRContext *context, ::std::optional<Location> location,
613  ConcatOp::Adaptor adaptor,
614  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
615  // Infer all dimension sizes by reducing based on inputs.
616  const Properties &prop = adaptor.getProperties();
617  int32_t axis = prop.axis.getValue().getSExtValue();
618  llvm::SmallVector<int64_t> outputShape;
619  bool hasRankedInput = false;
620  for (auto operand : adaptor.getOperands()) {
621  ShapeAdaptor operandShape(operand.getType());
622  if (!operandShape.hasRank())
623  continue;
624 
625  // Copy the Operand's rank.
626  if (!hasRankedInput)
627  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
628 
629  // Copy static, non-axis dim sizes and check that they agree across operands.
630  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
631  if (i == axis || operandShape.isDynamicDim(i))
632  continue;
633  if (outputShape[i] == ShapedType::kDynamic)
634  outputShape[i] = operandShape.getDimSize(i);
635  if (outputShape[i] != operandShape.getDimSize(i))
636  return emitOptionalError(location,
637  "Cannot concat tensors with different sizes"
638  " on the non-axis dimension ",
639  i);
640  }
641 
642  hasRankedInput = true;
643  }
644  Type inputType =
645  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
646  if (!hasRankedInput) {
647  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
648  return success();
649  }
650 
651  // Determine the dimension size along the concatenation axis.
652  int64_t concatDimSize = 0;
653  for (auto operand : adaptor.getOperands()) {
654  ShapeAdaptor operandShape(operand.getType());
655 
656  // We need to know the length of the concatenation axis of all inputs to
657  // determine the dimension size of the output shape.
658  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
659  concatDimSize = ShapedType::kDynamic;
660  break;
661  }
662 
663  concatDimSize += operandShape.getDimSize(axis);
664  }
665 
666  outputShape[axis] = concatDimSize;
667 
668  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
669  return success();
670 }
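// Worked example (illustrative only): concatenating tensor<2x3xf32> and
// tensor<2x5xf32> along axis = 1 infers tensor<2x8xf32>; if any operand's
// axis dimension is dynamic, the concatenated dimension is dynamic as well.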
671 
672 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
673  MLIRContext *context, ::std::optional<Location> location,
674  ValueShapeRange operands, DictionaryAttr attributes,
675  OpaqueProperties properties, RegionRange regions,
676  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
677  auto elementType = IntegerType::get(context, /*width=*/1);
678 
679  llvm::SmallVector<int64_t> outShape;
680  if (resolveBroadcastShape(operands, outShape).failed()) {
681  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
682  return success();
683  }
684 
685  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
686  return success();
687 }
688 
689 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
690  if (l.size() != r.size() || l.size() != 1)
691  return false;
692  return succeeded(verifyCompatibleShape(l[0], r[0]));
693 }
694 
695 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
696  MLIRContext *context, ::std::optional<Location> location,
697  FullyConnectedOp::Adaptor adaptor,
698  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
699  ShapeAdaptor inputShape(adaptor.getInput().getType());
700  ShapeAdaptor weightShape(adaptor.getWeight().getType());
701  ShapeAdaptor biasShape(adaptor.getBias().getType());
702 
703  // All shapes are dynamic.
704  SmallVector<int64_t> outShape;
705  outShape.resize(2, ShapedType::kDynamic);
706 
707  if (inputShape.hasRank()) {
708  outShape[0] = inputShape.getDimSize(0);
709  }
710 
711  if (weightShape.hasRank()) {
712  outShape[1] = weightShape.getDimSize(0);
713  }
714 
715  if (biasShape.hasRank()) {
716  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
717  : outShape[1];
718  }
719 
720  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
721  return success();
722 }
723 
724 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
725 
726 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
727  MLIRContext *context, ::std::optional<Location> location,
728  MatMulOp::Adaptor adaptor,
729  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
730  ShapeAdaptor lhsShape(adaptor.getA().getType());
731  ShapeAdaptor rhsShape(adaptor.getB().getType());
732 
733  // All shapes are dynamic.
734  SmallVector<int64_t> outShape;
735  outShape.resize(3, ShapedType::kDynamic);
736 
737  if (lhsShape.hasRank()) {
738  outShape[0] = lhsShape.getDimSize(0);
739  outShape[1] = lhsShape.getDimSize(1);
740  }
741 
742  if (rhsShape.hasRank()) {
743  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
744  : outShape[0];
745  outShape[2] = rhsShape.getDimSize(2);
746  }
747 
748  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
749  return success();
750 }
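// Worked example (illustrative only): multiplying tensor<1x4x8xf32> by
// tensor<1x8x16xf32> infers tensor<1x4x16xf32>; the batch and first result
// dimensions come from the LHS (falling back to the RHS batch when the LHS
// batch is dynamic) and the last dimension comes from the RHS.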
751 
752 LogicalResult tosa::PadOp::inferReturnTypeComponents(
753  MLIRContext *context, ::std::optional<Location> location,
754  PadOp::Adaptor adaptor,
755  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
756  ShapeAdaptor inputShape(adaptor.getInput1().getType());
757  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
758  SmallVector<int64_t> outputShape;
759 
760  // If both inputs have unknown shape, we cannot determine the shape of the
761  // output.
762  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
763  inferredReturnShapes.push_back(ShapedTypeComponents());
764  return success();
765  }
766 
767  // If the input rank is unknown, we can infer the output rank using the
768  // padding shape's first dim.
769  if (!inputShape.hasRank()) {
770  if (paddingShape.isDynamicDim(0)) {
771  inferredReturnShapes.push_back(ShapedTypeComponents());
772  return success();
773  }
774 
775  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
776  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
777  return success();
778  }
779 
780  DenseIntElementsAttr paddings;
781  // If the paddings value is not a constant, all dimensions must be dynamic.
782  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
783  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
784  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
785  return success();
786  }
787 
788  SmallVector<int64_t> paddingValues;
789  for (auto val : paddings) {
790  paddingValues.push_back(val.getSExtValue());
791  }
792 
793  outputShape.reserve(inputShape.getRank());
794  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
795  if (inputShape.isDynamicDim(i)) {
796  outputShape.push_back(ShapedType::kDynamic);
797  continue;
798  }
799 
800  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
801  paddingValues[i * 2 + 1]);
802  }
803 
804  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
805  return success();
806 }
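// Worked example (illustrative only): padding a tensor<2x3xf32> with constant
// paddings [[1, 1], [2, 2]] infers tensor<4x7xf32>, since each output
// dimension is input + pad_before + pad_after (2 + 1 + 1 and 3 + 2 + 2).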
807 
808 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
809  return to_vector(llvm::map_range(shape, [](int64_t dim) {
810  return dim == -1 ? ShapedType::kDynamic : dim;
811  }));
812 }
813 
814 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
815  MLIRContext *context, ::std::optional<Location> location,
816  SliceOp::Adaptor adaptor,
817  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
818  inferredReturnShapes.push_back(
819  ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
820  return success();
821 }
822 
823 LogicalResult tosa::SliceOp::verify() {
824  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
825  if (!inputType)
826  return success();
827 
828  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
829  return emitOpError(
830  "length of start attribute is not equal rank of input shape");
831 
832  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
833  return emitOpError(
834  "length of size attribute is not equal rank of input shape");
835 
836  return success();
837 }
838 
839 LogicalResult tosa::TableOp::inferReturnTypeComponents(
840  MLIRContext *context, ::std::optional<Location> location,
841  TableOp::Adaptor adaptor,
842  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
843  ShapeAdaptor inputShape(adaptor.getInput().getType());
844 
845  if (!inputShape.hasRank()) {
846  inferredReturnShapes.push_back(ShapedTypeComponents());
847  return success();
848  }
849 
850  inferredReturnShapes.resize(1);
851  inputShape.getDims(inferredReturnShapes[0]);
852  return success();
853 }
854 
855 LogicalResult tosa::TileOp::inferReturnTypeComponents(
856  MLIRContext *context, ::std::optional<Location> location,
857  TileOp::Adaptor adaptor,
858  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
859  ArrayRef<int64_t> multiples = adaptor.getMultiples();
860  ShapeAdaptor inputShape(adaptor.getInput1().getType());
861  SmallVector<int64_t> outputShape;
862  if (!inputShape.hasRank()) {
863  outputShape.resize(multiples.size(), ShapedType::kDynamic);
864  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
865  return success();
866  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
867  return failure();
868 
869  // Any non-dynamic dimension can be multiplied to a known size.
870  outputShape.reserve(multiples.size());
871  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
872  int64_t dim = inputShape.getDimSize(i);
873  if (dim != ShapedType::kDynamic)
874  dim *= multiples[i];
875  outputShape.push_back(dim);
876  }
877 
878  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
879  return success();
880 }
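// Worked example (illustrative only): tiling a tensor<2x?x3xf32> with
// multiples = [3, 1, 2] infers tensor<6x?x6xf32>; dynamic dimensions stay
// dynamic, static ones are multiplied by the corresponding multiple.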
881 
882 LogicalResult tosa::TileOp::verify() {
883  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
884  ShapedType outputType = llvm::cast<ShapedType>(getType());
885  auto multiples = getMultiples();
886 
887  if (inputType.hasRank()) {
888  if (static_cast<size_t>(inputType.getRank()) != multiples.size())
889  return emitOpError("expect 'multiples' array to have length ")
890  << inputType.getRank() << " but got " << multiples.size() << ".";
891  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
892  return emitOpError("expect same input and output tensor rank.");
893  } else if (outputType.hasRank() &&
894  static_cast<size_t>(outputType.getRank()) != multiples.size())
895  return emitOpError("expect 'multiples' array to have length ")
896  << outputType.getRank() << " but got " << multiples.size() << ".";
897 
898  return success();
899 }
900 
901 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
902  if (l.size() != r.size() || l.size() != 1)
903  return false;
904  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
905 }
906 
907 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
908  MLIRContext *context, ::std::optional<Location> location,
909  ReshapeOp::Adaptor adaptor,
910  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
911  ShapeAdaptor inputShape(adaptor.getInput1().getType());
912  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
913  llvm::SmallVector<int64_t> newShapeValue =
914  convertToMlirShape(adaptor.getNewShape());
915 
916  // We cannot infer from the total number of elements so we must take the
917  // shape attribute as exact.
918  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
919  inferredReturnShapes.push_back(
920  ShapedTypeComponents(newShapeValue, inputType));
921  return success();
922  }
923 
924  // Determine the number of elements covered by the slice of all static
925  // dimensions. This allows us to infer the length of the remaining dynamic
926  // dimension.
927  int64_t numElements = inputShape.getNumElements();
928  int64_t staticMul = 1;
929  for (auto val : newShapeValue) {
930  if (!ShapedType::isDynamic(val)) {
931  staticMul *= val;
932  }
933  }
934 
935  // Determine the length of the dynamic dimension.
936  for (auto &val : newShapeValue) {
937  if (ShapedType::isDynamic(val))
938  val = numElements / staticMul;
939  }
940 
941  inferredReturnShapes.push_back(
942  ShapedTypeComponents(newShapeValue, inputType));
943  return success();
944 }
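// Worked example (illustrative only): reshaping a static tensor<6x4xf32>
// with new_shape = [-1, 8] infers tensor<3x8xf32>, since the single dynamic
// extent is recovered as 24 elements / 8 = 3.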
945 
946 LogicalResult tosa::ReshapeOp::verify() {
947  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
948  ShapedType outputType = llvm::cast<ShapedType>(getType());
949 
950  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
951  return emitOpError() << "tensor has a dimension with size zero. Each "
952  "dimension of a tensor must have size >= 1";
953 
954  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
955  int64_t inputElementsNum = inputType.getNumElements();
956  int64_t outputElementsNum = outputType.getNumElements();
957  if (inputElementsNum != outputElementsNum) {
958  return emitOpError() << "Cannot reshape " << inputElementsNum
959  << " elements into " << outputElementsNum;
960  }
961  }
962  return mlir::success();
963 }
964 
965 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
966  // Perms must be constants.
967  DenseIntElementsAttr permsAttr;
968  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
969  return failure();
970 
971  // Transpose is not the identity transpose.
972  perms = llvm::to_vector(
973  llvm::map_range(permsAttr.getValues<APInt>(),
974  [](const APInt &val) { return val.getSExtValue(); }));
975 
976  return success();
977 }
978 
979 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
980  MLIRContext *context, ::std::optional<Location> location,
981  TransposeOp::Adaptor adaptor,
982  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
983  ShapeAdaptor inputShape(adaptor.getInput1().getType());
984  ShapeAdaptor permsShape(adaptor.getPerms().getType());
985 
986  // If the input rank or the permutation length is unknown, the output rank
987  // is unknown.
988  if (!inputShape.hasRank() || !permsShape.hasRank() ||
989  permsShape.isDynamicDim(0)) {
990  inferredReturnShapes.push_back(ShapedTypeComponents());
991  return success();
992  }
993 
994  // This would imply the number of permutations does not match the rank of the
995  // input which is illegal.
996  if (permsShape.getDimSize(0) != inputShape.getRank()) {
997  return failure();
998  }
999 
1000  // Without the input dims we cannot determine the output dim sizes but we
1001  // can determine the output rank.
1002  SmallVector<int64_t> outputShape;
1003  if (!inputShape.hasRank()) {
1004  outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
1005  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1006  return success();
1007  }
1008 
1009  // Rank-0 means no permutations matter.
1010  if (inputShape.getRank() == 0) {
1011  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1012  return success();
1013  }
1014 
1015  // Check whether the input dimensions are all the same.
1016  bool allTheSame = true;
1017  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1018  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1019  allTheSame = false;
1020  break;
1021  }
1022  }
1023 
1024  // If all of the input dimensions are the same we don't care about the
1025  // permutation.
1026  if (allTheSame) {
1027  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1028  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1029  return success();
1030  }
1031 
1032  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1033  // If the permutations are a constant we can directly determine the output
1034  // shape.
1035  DenseIntElementsAttr attr;
1036  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1037  attr.getType().getRank() == 1) {
1038  ShapeAdaptor permShape = attr;
1039  // Constant permutation must be the same length as the input rank.
1040  if (inputShape.getRank() != permShape.getRank())
1041  return emitOptionalError(location,
1042  "constant permutation must be the same length"
1043  " as the input rank");
1044 
1045  // Constant permutation values must be within the input rank.
1046  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1047  if (inputShape.getRank() <= permShape.getDimSize(i))
1048  return failure();
1049  }
1050 
1051  outputShape.reserve(inputShape.getRank());
1052  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1053  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1054  }
1055  }
1056 
1057  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1058  return success();
1059 }
1060 
1061 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1062  MLIRContext *context, ::std::optional<Location> location,
1063  GatherOp::Adaptor adaptor,
1064  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1065  llvm::SmallVector<int64_t> outputShape;
1066  outputShape.resize(3, ShapedType::kDynamic);
1067 
1068  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1069  if (valuesShape.hasRank()) {
1070  outputShape[0] = valuesShape.getDimSize(0);
1071  outputShape[2] = valuesShape.getDimSize(2);
1072  }
1073 
1074  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1075  if (indicesShape.hasRank()) {
1076  if (outputShape[0] == ShapedType::kDynamic)
1077  outputShape[0] = indicesShape.getDimSize(0);
1078  if (outputShape[1] == ShapedType::kDynamic)
1079  outputShape[1] = indicesShape.getDimSize(1);
1080  }
1081 
1082  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1083  return success();
1084 }
1085 
1086 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1087  MLIRContext *context, ::std::optional<Location> location,
1088  ResizeOp::Adaptor adaptor,
1089  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1090  llvm::SmallVector<int64_t, 4> outputShape;
1091  outputShape.resize(4, ShapedType::kDynamic);
1092 
1093  ShapeAdaptor inputShape(adaptor.getInput().getType());
1094  if (!inputShape.hasRank())
1095  return failure();
1096 
1097  outputShape[0] = inputShape.getDimSize(0);
1098  outputShape[3] = inputShape.getDimSize(3);
1099  int64_t inputHeight = inputShape.getDimSize(1);
1100  int64_t inputWidth = inputShape.getDimSize(2);
1101 
1102  if ((inputHeight == ShapedType::kDynamic) ||
1103  (inputWidth == ShapedType::kDynamic))
1104  return failure();
1105 
1106  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1107  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1108  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1109 
1110  // Compute the output shape based on attributes: scale, offset, and border.
1111  outputShape[1] =
1112  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1113  scaleInt[1]) +
1114  1;
1115 
1116  outputShape[2] =
1117  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1118  scaleInt[3]) +
1119  1;
1120 
1121  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1122  return success();
1123 }
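// Worked example (illustrative only): for an 8x8 spatial input with
// scale = [2, 1, 2, 1], offset = [0, 0] and border = [0, 0], the formula
// above gives ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15 for both spatial
// dimensions, i.e. an output of [N, 15, 15, C].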
1124 
1125 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1126  MLIRContext *context, ::std::optional<Location> location,
1127  ScatterOp::Adaptor adaptor,
1128  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1129  llvm::SmallVector<int64_t> outputShape;
1130  outputShape.resize(3, ShapedType::kDynamic);
1131 
1132  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1133  if (valuesInShape.hasRank()) {
1134  outputShape[0] = valuesInShape.getDimSize(0);
1135  outputShape[1] = valuesInShape.getDimSize(1);
1136  outputShape[2] = valuesInShape.getDimSize(2);
1137  }
1138 
1139  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1140  if (indicesShape.hasRank()) {
1141  if (outputShape[0] == ShapedType::kDynamic)
1142  outputShape[0] = indicesShape.getDimSize(0);
1143  }
1144 
1145  ShapeAdaptor inputShape(adaptor.getInput().getType());
1146  if (inputShape.hasRank()) {
1147  if (outputShape[0] == ShapedType::kDynamic)
1148  outputShape[0] = inputShape.getDimSize(0);
1149  if (outputShape[2] == ShapedType::kDynamic)
1150  outputShape[2] = inputShape.getDimSize(2);
1151  }
1152 
1153  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1154  return success();
1155 }
1156 
1157 static LogicalResult ReduceInferReturnTypes(
1158  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1159  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1160  int64_t axisVal = axis.getValue().getSExtValue();
1161  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1162  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1163  return success();
1164  }
1165 
1166  SmallVector<int64_t> outputShape;
1167  operandShape.getDims(outputShape);
1168  outputShape[axisVal] = 1;
1169  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1170  return success();
1171 }
1172 
1173 #define COMPATIBLE_RETURN_TYPES(OP) \
1174  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1175  if (l.size() != r.size() || l.size() != 1) \
1176  return false; \
1177  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1178  return false; \
1179  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1180  }
1181 
1182 #define REDUCE_SHAPE_INFER(OP) \
1183  LogicalResult OP::inferReturnTypeComponents( \
1184  MLIRContext *context, ::std::optional<Location> location, \
1185  OP::Adaptor adaptor, \
1186  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1187  Type inputType = \
1188  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1189  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1190  const Properties &prop = adaptor.getProperties(); \
1191  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1192  inferredReturnShapes); \
1193  } \
1194  COMPATIBLE_RETURN_TYPES(OP)
1195 
1196 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1197 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1198 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1199 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1200 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1201 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1202 #undef REDUCE_SHAPE_INFER
1203 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1204 #undef COMPATIBLE_RETURN_TYPES
1205 
1206 template <typename T>
1207 static LogicalResult verifyReduceOp(T op) {
1208  // All TOSA reduce ops have input, output and axis.
1209  TensorType inputType = op.getInput().getType();
1210  TensorType outputType = op.getOutput().getType();
1211  int32_t reduceAxis = op.getAxis();
1212 
1213  if (reduceAxis < 0) {
1214  op.emitOpError("reduce axis must not be negative");
1215  return failure();
1216  }
1217  if (inputType.hasRank()) {
1218  int64_t inputRank = inputType.getRank();
1219  // We allow for a special case where the input/output shape has rank 0 and
1220  // axis is also 0.
1221  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1222  op.emitOpError("expect input tensor rank (")
1223  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1224  << ")";
1225  return failure();
1226  }
1227  }
1228  if (outputType.hasRank()) {
1229  int64_t outputRank = outputType.getRank();
1230  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1231  op.emitOpError(
1232  "expect output tensor rank to be equal to input tensor rank");
1233  return failure();
1234  }
1235  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1236  op.emitOpError("expect output tensor rank (")
1237  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1238  << ")";
1239  return failure();
1240  }
1241  // We can only verify the reduced dimension size to be 1 if this is not the
1242  // special case of output rank == 0.
1243  if (outputRank != 0) {
1244  auto outputShape = outputType.getShape();
1245  if (!outputType.isDynamicDim(reduceAxis) &&
1246  outputShape[reduceAxis] != 1) {
1247  op.emitOpError("expect reduced dimension size to be 1, got ")
1248  << outputShape[reduceAxis];
1249  return failure();
1250  }
1251  }
1252  }
1253  return success();
1254 }
1255 
1256 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1257 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1258 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1259 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1260 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1261 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1262 
1263 static LogicalResult NAryInferReturnTypes(
1264  const ValueShapeRange &operands,
1265  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1266  llvm::SmallVector<int64_t> outShape;
1267  if (resolveBroadcastShape(operands, outShape).failed()) {
1268  inferredReturnShapes.push_back(ShapedTypeComponents());
1269  } else {
1270  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1271  }
1272  return success();
1273 }
1274 
1275 #define NARY_SHAPE_INFER(OP) \
1276  LogicalResult OP::inferReturnTypeComponents( \
1277  MLIRContext *context, ::std::optional<Location> location, \
1278  ValueShapeRange operands, DictionaryAttr attributes, \
1279  OpaqueProperties properties, RegionRange regions, \
1280  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1281  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1282  }
1283 
1284 NARY_SHAPE_INFER(tosa::AbsOp)
1285 NARY_SHAPE_INFER(tosa::AddOp)
1286 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1287 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1288 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1289 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1290 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1291 NARY_SHAPE_INFER(tosa::CastOp)
1292 NARY_SHAPE_INFER(tosa::CeilOp)
1293 NARY_SHAPE_INFER(tosa::ClampOp)
1294 NARY_SHAPE_INFER(tosa::ClzOp)
1295 NARY_SHAPE_INFER(tosa::DivOp)
1296 NARY_SHAPE_INFER(tosa::ExpOp)
1297 NARY_SHAPE_INFER(tosa::FloorOp)
1298 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1299 NARY_SHAPE_INFER(tosa::GreaterOp)
1300 NARY_SHAPE_INFER(tosa::IdentityOp)
1301 NARY_SHAPE_INFER(tosa::LogOp)
1302 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1303 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1304 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1305 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1306 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1307 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1308 NARY_SHAPE_INFER(tosa::MaximumOp)
1309 NARY_SHAPE_INFER(tosa::MinimumOp)
1310 NARY_SHAPE_INFER(tosa::MulOp)
1311 NARY_SHAPE_INFER(tosa::NegateOp)
1312 NARY_SHAPE_INFER(tosa::PowOp)
1313 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1314 NARY_SHAPE_INFER(tosa::RescaleOp)
1315 NARY_SHAPE_INFER(tosa::ReverseOp)
1316 NARY_SHAPE_INFER(tosa::RsqrtOp)
1317 NARY_SHAPE_INFER(tosa::SelectOp)
1318 NARY_SHAPE_INFER(tosa::SubOp)
1319 NARY_SHAPE_INFER(tosa::TanhOp)
1320 NARY_SHAPE_INFER(tosa::ErfOp)
1321 NARY_SHAPE_INFER(tosa::SigmoidOp)
1322 #undef NARY_SHAPE_INFER
1323 
1324 static LogicalResult poolingInferReturnTypes(
1325  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1326  ArrayRef<int64_t> pad,
1327  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1328  llvm::SmallVector<int64_t> outputShape;
1329  outputShape.resize(4, ShapedType::kDynamic);
1330 
1331  // If the input shape is unknown, we only know the output rank.
1332  if (!inputShape) {
1333  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1334  return success();
1335  }
1336 
1337  // Batch size and number of channels are unchanged by the pooling layer.
1338  outputShape[0] = inputShape.getDimSize(0);
1339  outputShape[3] = inputShape.getDimSize(3);
1340 
1341  int64_t height = inputShape.getDimSize(1);
1342  int64_t width = inputShape.getDimSize(2);
1343 
1344  if (!ShapedType::isDynamic(height)) {
1345  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1346  outputShape[1] = padded / stride[0] + 1;
1347  }
1348 
1349  if (!ShapedType::isDynamic(width)) {
1350  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1351  outputShape[2] = padded / stride[1] + 1;
1352  }
1353 
1354  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1355  return success();
1356 }
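// Worked example (illustrative only): a 32x32 spatial input with
// kernel = [2, 2], stride = [2, 2] and pad = [0, 0, 0, 0] gives
// (32 + 0 + 0 - 2) / 2 + 1 = 16 in each spatial dimension.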
1357 
1358 LogicalResult Conv2DOp::inferReturnTypeComponents(
1359  MLIRContext *context, ::std::optional<Location> location,
1360  Conv2DOp::Adaptor adaptor,
1361  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1362  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1363 
1364  int64_t inputWidth = ShapedType::kDynamic;
1365  int64_t inputHeight = ShapedType::kDynamic;
1366  int64_t weightWidth = ShapedType::kDynamic;
1367  int64_t weightHeight = ShapedType::kDynamic;
1368 
1369  // Input shape describes input width/height and batch.
1370 
1371  ShapeAdaptor inputShape(adaptor.getInput().getType());
1372  if (inputShape.hasRank()) {
1373  outputShape[0] = inputShape.getDimSize(0);
1374  inputHeight = inputShape.getDimSize(1);
1375  inputWidth = inputShape.getDimSize(2);
1376  }
1377 
1378  // Weight shape describes the filter width/height and the output channels.
1379  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1380  if (weightShape.hasRank()) {
1381  outputShape[3] = weightShape.getDimSize(0);
1382  weightHeight = weightShape.getDimSize(1);
1383  weightWidth = weightShape.getDimSize(2);
1384  }
1385 
1386  // Bias shape can describe the output channels.
1387  ShapeAdaptor biasShape(adaptor.getBias().getType());
1388  if (biasShape.hasRank()) {
1389  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1390  ? biasShape.getDimSize(0)
1391  : outputShape[3];
1392  }
1393 
1394  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1395  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1396  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1397 
1398  if (!ShapedType::isDynamic(inputHeight) &&
1399  !ShapedType::isDynamic(weightHeight)) {
1400  int64_t inputSize = inputHeight + padding[0] + padding[1];
1401  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1402  int64_t unstridedResult = inputSize - filterSize + 1;
1403  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1404  }
1405 
1406  if (!ShapedType::isDynamic(inputWidth) &&
1407  !ShapedType::isDynamic(weightWidth)) {
1408  int64_t inputSize = inputWidth + padding[2] + padding[3];
1409  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1410  int64_t unstridedResult = inputSize - filterSize + 1;
1411  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1412  }
1413 
1414  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1415  return success();
1416 }
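// Worked example (illustrative only): with inputHeight = 16,
// padding = [1, 1], a 3x3 kernel, dilation = [1, 1] and stride = [2, 2]:
//   inputSize       = 16 + 1 + 1       = 18
//   filterSize      = (3 - 1) * 1 + 1  = 3
//   unstridedResult = 18 - 3 + 1       = 16
//   outputShape[1]  = (16 - 1) / 2 + 1 = 8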
1417 
1418 LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1419 
1420 LogicalResult Conv3DOp::inferReturnTypeComponents(
1421  MLIRContext *context, ::std::optional<Location> location,
1422  Conv3DOp::Adaptor adaptor,
1423  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1424  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1425 
1426  int64_t inputWidth = ShapedType::kDynamic;
1427  int64_t inputHeight = ShapedType::kDynamic;
1428  int64_t inputDepth = ShapedType::kDynamic;
1429 
1430  int64_t weightWidth = ShapedType::kDynamic;
1431  int64_t weightHeight = ShapedType::kDynamic;
1432  int64_t weightDepth = ShapedType::kDynamic;
1433 
1434  // Input shape describes input width/height and batch.
1435  ShapeAdaptor inputShape(adaptor.getInput().getType());
1436  if (inputShape.hasRank()) {
1437  outputShape[0] = inputShape.getDimSize(0);
1438  inputDepth = inputShape.getDimSize(1);
1439  inputHeight = inputShape.getDimSize(2);
1440  inputWidth = inputShape.getDimSize(3);
1441  }
1442 
1443  // Weight shape describes the filter depth/height/width and output channels.
1444  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1445  if (weightShape.hasRank()) {
1446  outputShape[4] = weightShape.getDimSize(0);
1447  weightDepth = weightShape.getDimSize(1);
1448  weightHeight = weightShape.getDimSize(2);
1449  weightWidth = weightShape.getDimSize(3);
1450  }
1451 
1452  // Bias shape can describe the output channels.
1453  ShapeAdaptor biasShape(adaptor.getBias().getType());
1454  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1455  outputShape[4] = biasShape.getDimSize(0);
1456  }
1457 
1458  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1459  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1460  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1461 
1462  if (!ShapedType::isDynamic(inputDepth) &&
1463  !ShapedType::isDynamic(weightDepth)) {
1464  int32_t inputSize = inputDepth + pad[0] + pad[1];
1465  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1466  int32_t unstridedResult = inputSize - filterSize + 1;
1467  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1468  }
1469 
1470  if (!ShapedType::isDynamic(inputHeight) &&
1471  !ShapedType::isDynamic(weightHeight)) {
1472  int32_t inputSize = inputHeight + pad[2] + pad[3];
1473  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1474  int32_t unstridedResult = inputSize - filterSize + 1;
1475  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1476  }
1477 
1478  if (!ShapedType::isDynamic(inputWidth) &&
1479  !ShapedType::isDynamic(weightWidth)) {
1480  int32_t inputSize = inputWidth + pad[4] + pad[5];
1481  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1482  int32_t unstridedResult = inputSize - filterSize + 1;
1483  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1484  }
1485 
1486  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1487  return success();
1488 }
1489 
1490 LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1491 
1492 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1493  MLIRContext *context, ::std::optional<Location> location,
1494  AvgPool2dOp::Adaptor adaptor,
1495  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1496  ShapeAdaptor inputShape(adaptor.getInput().getType());
1497  const Properties &prop = adaptor.getProperties();
1498  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1499  inferredReturnShapes);
1500 }
1501 
1502 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1503  MLIRContext *context, ::std::optional<Location> location,
1504  MaxPool2dOp::Adaptor adaptor,
1505  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1506  ShapeAdaptor inputShape(adaptor.getInput().getType());
1507  const Properties &prop = adaptor.getProperties();
1508  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1509  inferredReturnShapes);
1510 }
1511 
1512 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1513  MLIRContext *context, ::std::optional<Location> location,
1514  DepthwiseConv2DOp::Adaptor adaptor,
1515  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1516  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1517 
1518  int64_t inputWidth = ShapedType::kDynamic;
1519  int64_t inputHeight = ShapedType::kDynamic;
1520  int64_t inputChannels = ShapedType::kDynamic;
1521 
1522  int64_t weightWidth = ShapedType::kDynamic;
1523  int64_t weightHeight = ShapedType::kDynamic;
1524  int64_t depthChannels = ShapedType::kDynamic;
1525 
1526  // Input shape describes input width/height and batch.
1527  ShapeAdaptor inputShape(adaptor.getInput().getType());
1528  if (inputShape.hasRank()) {
1529  outputShape[0] = inputShape.getDimSize(0);
1530  inputHeight = inputShape.getDimSize(1);
1531  inputWidth = inputShape.getDimSize(2);
1532  inputChannels = inputShape.getDimSize(3);
1533  }
1534 
1535  // Weight shape describes the filter height/width, input channels, and the depth multiplier.
1536  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1537  if (weightShape.hasRank()) {
1538  weightHeight = weightShape.getDimSize(0);
1539  weightWidth = weightShape.getDimSize(1);
1540  inputChannels = ShapedType::isDynamic(inputChannels)
1541  ? weightShape.getDimSize(2)
1542  : inputChannels;
1543  depthChannels = weightShape.getDimSize(3);
1544  }
1545 
1546  // If both inputChannels and depthChannels are available, we can determine
1547  // the output channels.
1548  if (!ShapedType::isDynamic(inputChannels) &&
1549  !ShapedType::isDynamic(depthChannels)) {
1550  outputShape[3] = inputChannels * depthChannels;
1551  }
1552 
1553  // Bias shape can describe the output channels.
1554  ShapeAdaptor biasShape(adaptor.getBias().getType());
1555  if (biasShape.hasRank()) {
1556  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1557  ? biasShape.getDimSize(0)
1558  : outputShape[3];
1559  }
1560 
1561  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1562  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1563  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1564 
1565  if (!ShapedType::isDynamic(inputHeight) &&
1566  !ShapedType::isDynamic(weightHeight)) {
1567  int64_t inputSize = inputHeight + padding[0] + padding[1];
1568  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1569  int64_t unstridedResult = inputSize - filterSize + 1;
1570  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1571  }
1572 
1573  if (!ShapedType::isDynamic(inputWidth) &&
1574  !ShapedType::isDynamic(weightWidth)) {
1575  int64_t inputSize = inputWidth + padding[2] + padding[3];
1576  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1577  int64_t unstridedResult = inputSize - filterSize + 1;
1578  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1579  }
1580 
1581  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1582  return success();
1583 }
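
// Illustrative note: for tosa.depthwise_conv2d the weight is read here as
// [KH, KW, C, M], so an input with C = 8 channels and a weight with depth
// multiplier M = 3 gives outputShape[3] = 8 * 3 = 24 output channels.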
1584 
1585 LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1586 
1587 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1588  MLIRContext *context, ::std::optional<Location> location,
1589  TransposeConv2DOp::Adaptor adaptor,
1590  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1591  // Seed outputShape from the out_shape attribute; dynamic entries may be refined below.
1592  llvm::SmallVector<int64_t> outputShape =
1593  convertToMlirShape(adaptor.getOutShape());
1594 
1595  int64_t inputWidth = ShapedType::kDynamic;
1596  int64_t inputHeight = ShapedType::kDynamic;
1597  int64_t weightWidth = ShapedType::kDynamic;
1598  int64_t weightHeight = ShapedType::kDynamic;
1599 
1600  // Input shape describes input width/height and batch.
1601  ShapeAdaptor inputShape(adaptor.getInput().getType());
1602  if (inputShape.hasRank()) {
1603  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1604  ? inputShape.getDimSize(0)
1605  : outputShape[0];
1606  inputHeight = inputShape.getDimSize(1);
1607  inputWidth = inputShape.getDimSize(2);
1608  }
1609 
1610  // Weight shape describes the filter height/width and the output channels.
1611  ShapeAdaptor weightShape(adaptor.getFilter().getType());
1612  if (weightShape.hasRank()) {
1613  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1614  ? weightShape.getDimSize(0)
1615  : outputShape[3];
1616  weightHeight = weightShape.getDimSize(1);
1617  weightWidth = weightShape.getDimSize(2);
1618  }
1619 
1620  // Bias shape can describe the output channels.
1621  ShapeAdaptor biasShape(adaptor.getBias().getType());
1622  if (biasShape.hasRank()) {
1623  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1624  ? biasShape.getDimSize(0)
1625  : outputShape[3];
1626  }
1627 
1628  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
1629  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1630 
1631  if (!ShapedType::isDynamic(inputHeight) &&
1632  !ShapedType::isDynamic(weightHeight)) {
1633  int64_t calculateSize =
1634  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
1635  outputShape[1] =
1636  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
1637  }
1638 
1639  if (!ShapedType::isDynamic(inputWidth) &&
1640  !ShapedType::isDynamic(weightWidth)) {
1641  int64_t calculateSize =
1642  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
1643  outputShape[2] =
1644  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
1645  }
1646 
1647  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1648  return success();
1649 }
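
// Illustrative note: with a static input height of 8, stride = {2, ...},
// out_pad = {0, 0, ...} and a kernel height of 3, the formula above yields
// (8 - 1) * 2 + 0 + 0 + 3 = 17 for outputShape[1]; a static entry in the
// out_shape attribute takes precedence over this computed value.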
1650 
1651 LogicalResult IfOp::inferReturnTypeComponents(
1652  MLIRContext *context, ::std::optional<Location> location,
1653  IfOp::Adaptor adaptor,
1654  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1655  llvm::SmallVector<tosa::YieldOp> yieldOps;
1656  for (Region *region : adaptor.getRegions()) {
1657  for (auto &block : *region)
1658  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1659  yieldOps.push_back(returnOp);
1660  }
1661 
1662  if (yieldOps.empty())
1663  return failure();
1664 
1665  // Get the initial type information for the yield op.
1666  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1667  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1668  for (auto operand : yieldOps.front().getOperands()) {
1669  resultKnowledge.push_back(
1670  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1671  }
1672 
1673  for (auto yieldOp : yieldOps) {
1674  if (resultKnowledge.size() != yieldOp.getNumOperands())
1675  return failure();
1676 
1677  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1678  int32_t index = it.index();
1679  auto meet = ValueKnowledge::meet(
1680  resultKnowledge[index],
1681  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1682  if (!meet)
1683  continue;
1684  resultKnowledge[index] = meet;
1685  }
1686  }
1687 
1688  for (const ValueKnowledge &result : resultKnowledge) {
1689  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1690  }
1691 
1692  return success();
1693 }
1694 
1695 LogicalResult WhileOp::inferReturnTypeComponents(
1696  MLIRContext *context, ::std::optional<Location> location,
1697  WhileOp::Adaptor adaptor,
1698  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1699  llvm::SmallVector<tosa::YieldOp> yieldOps;
1700  for (auto &block : adaptor.getBody())
1701  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1702  yieldOps.push_back(returnOp);
1703 
1704  // TOSA's while must have a tosa.yield as its terminator. If none is found, this
1705  // tosa.while is invalid.
1706  if (yieldOps.empty())
1707  return failure();
1708 
1709  // Get the initial type information from the operand types.
1710  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1711  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1712  for (auto operand : yieldOps.front().getOperands()) {
1713  resultKnowledge.push_back(
1714  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1715  }
1716 
1717  for (auto yieldOp : yieldOps) {
1718  if (resultKnowledge.size() != yieldOp.getNumOperands())
1719  return failure();
1720 
1721  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1722  int32_t index = it.index();
1723  if (auto meet = ValueKnowledge::meet(
1724  resultKnowledge[index],
1725  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1726  resultKnowledge[index] = meet;
1727  }
1728  }
1729  }
1730 
1731  for (const ValueKnowledge &result : resultKnowledge) {
1732  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1733  }
1734 
1735  return success();
1736 }
1737 
1738 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
1739  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
1740  return llvm::to_vector<4>(vt.getShape());
1741  return std::nullopt;
1742 }
1743 
1744 // The parse and print of IfOp are modeled on the SCF dialect implementation.
1745 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1746  // Create the 'then' and 'else' regions.
1747  result.regions.reserve(2);
1748  Region *thenRegion = result.addRegion();
1749  Region *elseRegion = result.addRegion();
1750 
1751  auto &builder = parser.getBuilder();
1752  OpAsmParser::UnresolvedOperand cond;
1753  // Create an i1 tensor type for the boolean condition.
1754  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1755  if (parser.parseOperand(cond) ||
1756  parser.resolveOperand(cond, i1Type, result.operands))
1757  return failure();
1758  // Parse optional results type list.
1759  if (parser.parseOptionalArrowTypeList(result.types))
1760  return failure();
1761  // Parse the 'then' region.
1762  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1763  return failure();
1764 
1765  // If we find an 'else' keyword then parse the 'else' region.
1766  if (!parser.parseOptionalKeyword("else")) {
1767  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1768  return failure();
1769  }
1770 
1771  // Parse the optional attribute list.
1772  if (parser.parseOptionalAttrDict(result.attributes))
1773  return failure();
1774  return success();
1775 }
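
// Illustrative note: the custom form accepted by this parser looks roughly
// like the following (operand and type names are made up):
//
//   %r = tosa.cond_if %cond -> (tensor<f32>) {
//     tosa.yield %a : tensor<f32>
//   } else {
//     tosa.yield %b : tensor<f32>
//   }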
1776 
1777 void IfOp::print(OpAsmPrinter &p) {
1778  bool printBlockTerminators = false;
1779 
1780  p << " " << getCond();
1781  if (!getResults().empty()) {
1782  p << " -> (" << getResultTypes() << ")";
1783  // Print yield explicitly if the op defines values.
1784  printBlockTerminators = true;
1785  }
1786  p << ' ';
1787  p.printRegion(getThenBranch(),
1788  /*printEntryBlockArgs=*/false,
1789  /*printBlockTerminators=*/printBlockTerminators);
1790 
1791  // Print the 'else' region if it exists and has a block.
1792  auto &elseRegion = getElseBranch();
1793  if (!elseRegion.empty()) {
1794  p << " else ";
1795  p.printRegion(elseRegion,
1796  /*printEntryBlockArgs=*/false,
1797  /*printBlockTerminators=*/printBlockTerminators);
1798  }
1799 
1800  p.printOptionalAttrDict((*this)->getAttrs());
1801 }
1802 
1803 LogicalResult ReverseOp::verify() {
1804  TensorType inputType = getInput().getType();
1805  TensorType outputType = getOutput().getType();
1806  int32_t reverseAxis = getAxis();
1807 
1808  if (reverseAxis < 0)
1809  return emitOpError("expected non-negative reverse axis");
1810  if (inputType.hasRank()) {
1811  int64_t inputRank = inputType.getRank();
1812  // We allow for a special case where the input/output shape has rank 0 and
1813  // axis is also 0.
1814  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1815  return emitOpError("expect input tensor rank (")
1816  << inputRank << ") to be larger than reverse axis (" << reverseAxis
1817  << ")";
1818  }
1819  if (outputType.hasRank()) {
1820  int64_t outputRank = outputType.getRank();
1821  if (inputType.hasRank() && outputRank != inputType.getRank())
1822  return emitOpError(
1823  "expect output tensor rank to be equal to input tensor rank");
1824  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
1825  return emitOpError("expect output tensor rank (")
1826  << outputRank << ") to be larger than reverse axis ("
1827  << reverseAxis << ")";
1828  }
1829  return success();
1830 }
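
// Illustrative note: with a rank-3 input, axis values 0..2 pass this check and
// axis 3 is rejected; the rank-0 / axis-0 combination is the one special case
// that is still accepted.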
1831 
1832 // The parse and print of WhileOp are modeled on the SCF dialect implementation.
1833 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
1834  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1835  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1836  Region *cond = result.addRegion();
1837  Region *body = result.addRegion();
1838 
1839  OptionalParseResult listResult =
1840  parser.parseOptionalAssignmentList(regionArgs, operands);
1841  if (listResult.has_value() && failed(listResult.value()))
1842  return failure();
1843 
1844  FunctionType functionType;
1845  SMLoc typeLoc = parser.getCurrentLocation();
1846  if (failed(parser.parseColonType(functionType)))
1847  return failure();
1848 
1849  result.addTypes(functionType.getResults());
1850 
1851  if (functionType.getNumInputs() != operands.size()) {
1852  return parser.emitError(typeLoc)
1853  << "expected as many input types as operands "
1854  << "(expected " << operands.size() << " got "
1855  << functionType.getNumInputs() << ")";
1856  }
1857 
1858  // Resolve input operands.
1859  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1860  parser.getCurrentLocation(),
1861  result.operands)))
1862  return failure();
1863 
1864  // Propagate the types into the region arguments.
1865  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1866  regionArgs[i].type = functionType.getInput(i);
1867 
1868  return failure(parser.parseRegion(*cond, regionArgs) ||
1869  parser.parseKeyword("do") || parser.parseRegion(*body) ||
1870  parser.parseOptionalAttrDictWithKeyword(result.attributes));
1871 }
1872 
1873 static void printInitializationList(OpAsmPrinter &parser,
1874  Block::BlockArgListType blocksArgs,
1875  ValueRange initializers,
1876  StringRef prefix = "") {
1877  assert(blocksArgs.size() == initializers.size() &&
1878  "expected same length of arguments and initializers");
1879  if (initializers.empty())
1880  return;
1881 
1882  parser << prefix << '(';
1883  llvm::interleaveComma(
1884  llvm::zip(blocksArgs, initializers), parser,
1885  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
1886  parser << ")";
1887 }
1888 
1889 void WhileOp::print(OpAsmPrinter &parser) {
1890  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
1891  " ");
1892  parser << " : ";
1893  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
1894  parser << ' ';
1895  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
1896  parser << " do ";
1897  parser.printRegion(getBody());
1898  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1899 }
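
// Illustrative note: the printed form produced here looks roughly like the
// following (names are made up):
//
//   %r = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...
//     tosa.yield %continue : tensor<i1>
//   } do {
//     ...
//     tosa.yield %next : tensor<i32>
//   }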
1900 
1901 //===----------------------------------------------------------------------===//
1902 // TOSA Attribute Definitions.
1903 //===----------------------------------------------------------------------===//
1904 
1905 #define GET_ATTRDEF_CLASSES
1906 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
1907 
1908 //===----------------------------------------------------------------------===//
1909 // TOSA Operator Definitions.
1910 //===----------------------------------------------------------------------===//
1911 
1912 #define GET_OP_CLASSES
1913 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"