TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Tosa dialect interface includes.
40 //===----------------------------------------------------------------------===//
41 
42 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
43 
44 namespace {
45 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // Dialect Function Inliner Interface.
49 //===----------------------------------------------------------------------===//
50 struct TosaInlinerInterface : public DialectInlinerInterface {
51  using DialectInlinerInterface::DialectInlinerInterface;
52 
53  //===--------------------------------------------------------------------===//
54  // Analysis Hooks.
55  //===--------------------------------------------------------------------===//
56 
57  /// All operations can be inlined by default.
58  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
59  IRMapping &map) const final {
60  return true;
61  }
62 
63  /// All regions with If and While parent operators can be inlined.
64  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return (isa<tosa::IfOp>(dest->getParentOp()) ||
67  isa<tosa::WhileOp>(dest->getParentOp()));
68  }
69 };
70 
71 /// This class implements the bytecode interface for the Tosa dialect.
72 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
73  TosaDialectBytecodeInterface(Dialect *dialect)
74  : BytecodeDialectInterface(dialect) {}
75 
76  //===--------------------------------------------------------------------===//
77  // Attributes
78 
79  Attribute readAttribute(DialectBytecodeReader &reader) const override {
80  return ::readAttribute(getContext(), reader);
81  }
82 
83  LogicalResult writeAttribute(Attribute attr,
84  DialectBytecodeWriter &writer) const override {
85  return ::writeAttribute(attr, writer);
86  }
87 
88  //===--------------------------------------------------------------------===//
89  // Types
90 
91  Type readType(DialectBytecodeReader &reader) const override {
92  return ::readType(getContext(), reader);
93  }
94 
95  LogicalResult writeType(Type type,
96  DialectBytecodeWriter &writer) const override {
97  return ::writeType(type, writer);
98  }
99 
100  void writeVersion(DialectBytecodeWriter &writer) const final {
101  // TODO: Populate.
102  }
103 
104  std::unique_ptr<DialectVersion>
105  readVersion(DialectBytecodeReader &reader) const final {
106  // TODO: Populate
107  reader.emitError("Dialect does not support versioning");
108  return nullptr;
109  }
110 
111  LogicalResult upgradeFromVersion(Operation *topLevelOp,
112  const DialectVersion &version) const final {
113  return success();
114  }
115 };
116 
117 } // namespace
118 
119 //===----------------------------------------------------------------------===//
120 // TOSA control flow support.
121 //===----------------------------------------------------------------------===//
122 
123 /// Returns the while loop body.
124 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
125 
126 //===----------------------------------------------------------------------===//
127 // Tosa dialect initialization.
128 //===----------------------------------------------------------------------===//
129 
130 void TosaDialect::initialize() {
131  addOperations<
132 #define GET_OP_LIST
133 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
134  >();
135  addAttributes<
136 #define GET_ATTRDEF_LIST
137 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
138  >();
139  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
140  declarePromisedInterfaces<
141  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
142  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
143  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
144  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
145  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
146  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
147  GreaterEqualOp, MatMulOp>();
148 }
149 
150 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
151  Type type, Location loc) {
152  // Tosa dialect constants only support ElementsAttr unlike standard dialect
153  // constant which supports all attributes.
154  if (llvm::isa<ElementsAttr>(value))
155  return builder.create<tosa::ConstOp>(loc, type,
156  llvm::cast<ElementsAttr>(value));
157  return nullptr;
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parsers and printers
162 //===----------------------------------------------------------------------===//
163 
164 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
165  Attribute &attr) {
166  if (succeeded(parser.parseOptionalEqual())) {
167  if (failed(parser.parseAttribute(attr))) {
168  return parser.emitError(parser.getCurrentLocation())
169  << "expected attribute";
170  }
171  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
172  typeAttr = TypeAttr::get(typedAttr.getType());
173  }
174  return success();
175  }
176 
177  Type type;
178  if (failed(parser.parseColonType(type))) {
179  return parser.emitError(parser.getCurrentLocation()) << "expected type";
180  }
181  typeAttr = TypeAttr::get(type);
182 
183  return success();
184 }
185 
186 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
187  Attribute attr) {
188  bool needsSpace = false;
189  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
190  if (!typedAttr || typedAttr.getType() != type.getValue()) {
191  p << ": ";
192  p.printAttribute(type);
193  needsSpace = true; // subsequent attr value needs a space separator
194  }
195  if (attr) {
196  if (needsSpace)
197  p << ' ';
198  p << "= ";
199  p.printAttribute(attr);
200  }
201 }
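// For example, this custom directive round-trips either a bare type,
//   : tensor<4xf32>
// or a typed initial value,
//   = dense<0.0> : tensor<4xf32>
// (illustrative syntax; the ops that use the directive are defined in ODS).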
202 
203 //===----------------------------------------------------------------------===//
204 // TOSA Operator Verifiers.
205 //===----------------------------------------------------------------------===//
206 
207 template <typename T>
208 static LogicalResult verifyConvOp(T op) {
209  // All TOSA conv ops have an input() and weight().
210  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
211  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
212 
213  // Must be ranked tensor types
214  if (!inputType) {
215  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
216  return failure();
217  }
218  if (!weightType) {
219  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
220  return failure();
221  }
222 
223  auto inputEType = inputType.getElementType();
224  auto weightEType = weightType.getElementType();
225 
226  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
227  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
228 
229  // Either both must be quantized or both unquantized.
230  if (inputIsQuant != weightIsQuant) {
231  op.emitOpError(
232  "expect both input and weight to be float or not together, got ")
233  << inputEType << " and " << weightEType;
234  return failure();
235  }
236 
237  // A quantized type must have a constructed quantization attr, and an
238  // unquantized (float) type must not have one.
239  if ((inputIsQuant && !op.getQuantizationInfo()) ||
240  (!inputIsQuant && op.getQuantizationInfo())) {
241  op.emitOpError("quantizationattr is required for quantized type, and not "
242  "allowed for float type");
243  return failure();
244  }
245  return success();
246 }
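// For example, a conv whose input element type is f32 but whose weight element
// type is i8 fails this check, as does a quantized (integer) conv without a
// quantization_info attribute or a float conv that carries one.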
247 
248 LogicalResult tosa::ConstOp::verify() {
249 
250  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
251  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
252 
253  if (!attrType || !outputType) {
254  emitOpError("expected tensors for attr/result type");
255  return failure();
256  }
257 
258  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
259  outputType.getElementType())) {
260  if (result.getStorageType() == attrType.getElementType())
261  return success();
262  }
263 
264  if (attrType.getElementType() != outputType.getElementType()) {
265  emitOpError("expected same attr/result element types");
266  return failure();
267  }
268 
269  return success();
270 }
271 
272 LogicalResult tosa::ArgMaxOp::verify() {
273  // Ensure output is of 32-bit integer
274  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
275  if (!resultETy.isIntOrIndex())
276  return emitOpError("result tensor is not of integer type");
277 
278  // Ensure axis is within the tensor rank
279  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
280  const int64_t axis = getAxisAttr().getInt();
281  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
282  return emitOpError("specified axis is outside the rank of the tensor");
283 
284  return success();
285 }
286 
287 LogicalResult tosa::AvgPool2dOp::verify() {
288  auto inputType = llvm::cast<ShapedType>(getInput().getType());
289 
290  auto inputETy = inputType.getElementType();
291  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
292 
293  if (auto quantType =
294  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
295  inputETy = quantType.getStorageType();
296 
297  if (auto quantType =
298  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
299  resultETy = quantType.getStorageType();
300 
301  auto accType = getAccType();
302  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
303  return emitOpError("accumulator type for integer tensor is not i32");
304 
305  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
306  return emitOpError("accumulator type for f16 tensor is not f16/f32");
307 
308  if (inputETy.isBF16() && !accType.isF32())
309  return emitOpError("accumulator type for bf16 tensor is not f32");
310 
311  if (inputETy.isF32() && !accType.isF32())
312  return emitOpError("accumulator type for f32 tensor is not f32");
313 
314  if ((inputETy.isF32() && resultETy.isF32()) ||
315  (inputETy.isF16() && resultETy.isF16()) ||
316  (inputETy.isBF16() && resultETy.isBF16()) ||
317  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
318  (inputETy.isInteger(16) && resultETy.isInteger(16)))
319  return success();
320 
321  return emitOpError("input/output element types are incompatible.");
322 }
323 
324 LogicalResult tosa::ClampOp::verify() {
325  mlir::Type inputETy =
326  llvm::cast<ShapedType>(getInput().getType()).getElementType();
327  if (auto quantType =
328  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
329  inputETy = quantType.getStorageType();
330  }
331  mlir::Type maxFpType = getMaxFpAttr().getType();
332  mlir::Type minFpType = getMinFpAttr().getType();
333  mlir::Type outputETy =
334  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
335  if (auto quantType =
336  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
337  outputETy = quantType.getStorageType();
338  }
339  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
340 
341  if (inputETy != outputETy)
342  return emitOpError("input/output element types are incompatible.");
343 
344  // If the input datatype is float, check that the two min/max_fp attributes
345  // share the same type and that their type is either the same as the input's
346  // datatype, or a float type whose bitwidth > input datatype bitwidth.
347  if (!inputETy.isInteger(dataTypeBitWidth)) {
348  if (((maxFpType != minFpType) ||
349  (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
350  inputETy.getIntOrFloatBitWidth())))
351  return emitOpError("min/max attributes types are incompatible with "
352  "input/output element types.");
353  }
354 
355  return success();
356 }
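// For example, an f16 input may use f16 or f32 min_fp/max_fp attributes (both
// of the same type), but an f16 min_fp paired with an f32 max_fp is rejected,
// as is any op whose input and output element types differ.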
357 
358 //===----------------------------------------------------------------------===//
359 // TOSA Operator Quantization Builders.
360 //===----------------------------------------------------------------------===//
361 
362 /// This builder is called on all convolution operators except TransposeConv,
363 /// which has specialized output shape semantics. The builder also defines the
364 /// bitwidth of the output given the bit width of the input & weight content.
365 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
366  Type outputType, Value input, Value weight,
367  Value bias, DenseI64ArrayAttr pad,
368  DenseI64ArrayAttr stride,
369  DenseI64ArrayAttr dilation) {
370 
371  result.addOperands({input, weight, bias});
372  result.addAttribute("pad", pad);
373  result.addAttribute("stride", stride);
374  result.addAttribute("dilation", dilation);
375 
376  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
377  if (quantAttr) {
378  result.addAttribute("quantization_info", quantAttr);
379  result.addTypes(
380  buildConvOpResultTypeInfo(builder, outputType, input, weight));
381  } else {
382  result.addTypes(outputType);
383  }
384 }
385 
386 /// Handles tosa.transpose_conv2d which has outpad and output shape
387 /// attributes.
388 static void buildTransConvOpWithQuantInfo(
389  OpBuilder &builder, OperationState &result, Type outputType, Value input,
390  Value weight, Value bias, DenseI64ArrayAttr outpad,
391  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
392  result.addOperands({input, weight, bias});
393  result.addAttribute("out_pad", outpad);
394  result.addAttribute("stride", stride);
395  result.addAttribute("out_shape", outputShape);
396  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
397 
398  if (quantAttr) {
399  result.addAttribute("quantization_info", quantAttr);
400  result.addTypes(
401  buildConvOpResultTypeInfo(builder, outputType, input, weight));
402  } else {
403  result.addTypes(outputType);
404  }
405 }
406 
407 /// The tosa.fully_connected op has its own builder as it does not have
408 /// strides/dilation/padding.
409 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
410  Type outputType, Value input, Value weight,
411  Value bias) {
412 
413  result.addOperands({input, weight, bias});
414  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
415  if (quantAttr) {
416  result.addAttribute("quantization_info", quantAttr);
417  result.addTypes(
418  buildConvOpResultTypeInfo(builder, outputType, input, weight));
419  } else {
420  result.addTypes(outputType);
421  }
422 }
423 
424 /// The tosa.matmul op is also intended to be generated when a
425 /// fully_connected op must be constructed but the weight is not a constant.
426 /// In this case, the fully_connected op must be expressed using matmul.
427 /// TODO: Add link to the legalization document explaining this.
428 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
429  OperationState &result, Type outputType,
430  Value a, Value b) {
431  result.addOperands({a, b});
432  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
433 
434  if (quantAttr) {
435  result.addAttribute("quantization_info", quantAttr);
436 
437  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
438  assert(inputType && "Input must be a shaped tensor type!");
439 
440  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
441  inputType.getElementType());
442  assert(inputQType && "Tensor must have quantized datatype!");
443 
444  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
445 
446  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
447  assert(outputShapedType && "Output must be a shaped type");
448 
449  IntegerType accElementType;
450  if (inputBits == 16)
451  accElementType = builder.getIntegerType(48);
452  else
453  accElementType = builder.getI32Type();
454  auto accType = outputShapedType.clone(accElementType);
455  result.addTypes(accType);
456  } else {
457  result.addTypes(outputType);
458  }
459 }
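// For example, quantized i8*i8 operands produce an i32 accumulator result type
// here, while i16*i16 operands produce an i48 accumulator result type.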
460 
461 /// Both the tosa.avg_pool2d and unary ops use the same
462 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
463 /// as it has additional parameters not part of the unary ops.
464 static void
465 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
466  Type outputType, Value input,
467  DenseArrayAttr kernel, DenseArrayAttr stride,
468  DenseArrayAttr pad, TypeAttr accType) {
469  result.addOperands(input);
470  result.addAttribute("kernel", kernel);
471  result.addAttribute("stride", stride);
472  result.addAttribute("pad", pad);
473  result.addAttribute("acc_type", accType);
474  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
475  if (quantAttr)
476  result.addAttribute("quantization_info", quantAttr);
477  result.types.push_back(outputType);
478 }
479 
480 /// This builder is called on single-parameter unary operators that have a
481 /// scale relationship between their input and output, expressed by the
482 /// UnaryOpQuantizationAttr.
483 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
484  OperationState &result, Type outputType,
485  Value input) {
486  result.addOperands(input);
487  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
488  if (quantAttr)
489  result.addAttribute("quantization_info", quantAttr);
490  result.types.push_back(outputType);
491 }
492 
493 /// This builder is called on the TOSA pad operator that needs to create its
494 /// own OptionalAttr quantization_attr parameter to scale the padding values
495 /// correctly. No pad_const is interpreted as zero-padding.
496 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
497  Type outputType, Value input,
498  Value paddings) {
499  result.addOperands({input, paddings});
500  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
501  if (quantAttr)
502  result.addAttribute("quantization_info", quantAttr);
503  result.types.push_back(outputType);
504 }
505 
506 /// This builder is called on the TOSA pad operator when an explicit pad_const
507 /// value is passed in. It also optionally constructs quantization_attr.
508 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
509  OperationState &result,
510  Type outputType, Value input,
511  Value paddings,
512  Value padConst) {
513  result.addOperands({input, paddings, padConst});
514  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
515  if (quantAttr)
516  result.addAttribute("quantization_info", quantAttr);
517  result.types.push_back(outputType);
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // TOSA Operator Return Type Inference.
522 //===----------------------------------------------------------------------===//
523 
524 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
525  SmallVector<int64_t> &outShape) {
526  int64_t outRank = 0;
527  for (int i = 0, e = operands.size(); i != e; ++i) {
528  auto shape = operands.getShape(i);
529  if (!shape.hasRank()) {
530  // TODO(jennik): Update function to have better case handling for
531  // invalid operands and for ranked tensors.
532  return failure();
533  }
534  outRank = std::max<int64_t>(outRank, shape.getRank());
535  }
536 
537  outShape.resize(outRank, 1);
538 
539  for (int i = 0, e = operands.size(); i != e; ++i) {
540  auto shape = operands.getShape(i);
541  auto rankDiff = outShape.size() - shape.getRank();
542 
543  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
544  auto dim1 = outShape[i + rankDiff];
545  auto dim2 = shape.getDimSize(i);
546  auto resolvedDim = dim1;
547 
548  if (dim1 == 1) {
549  resolvedDim = dim2;
550  } else if (dim2 == 1) {
551  resolvedDim = dim1;
552  } else if (dim1 != dim2) {
553  return failure();
554  }
555  outShape[i + rankDiff] = resolvedDim;
556  }
557  }
558 
559  return success();
560 }
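// For example, operand shapes 2x1x3 and 4x3 resolve to the broadcast shape
// 2x4x3, while 2x3 and 4x3 fail because 2 and 4 cannot be reconciled.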
561 
562 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
563  MLIRContext *context, ::std::optional<Location> location,
564  ArgMaxOp::Adaptor adaptor,
565  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
566  ShapeAdaptor inputShape(adaptor.getInput().getType());
567  IntegerAttr axis = adaptor.getProperties().axis;
568  int32_t axisVal = axis.getValue().getSExtValue();
569 
570  if (!inputShape.hasRank()) {
571  inferredReturnShapes.push_back(ShapedTypeComponents());
572  return success();
573  }
574 
575  SmallVector<int64_t> outShape;
576  outShape.reserve(inputShape.getRank() - 1);
577  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
578  if (i == axisVal)
579  continue;
580  outShape.push_back(inputShape.getDimSize(i));
581  }
582 
583  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
584  return success();
585 }
586 
587 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
588  MLIRContext *context, ::std::optional<Location> location,
589  RFFT2dOp::Adaptor adaptor,
590  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
591  ShapeAdaptor inputShape(adaptor.getInput().getType());
592 
593  if (!inputShape.hasRank())
594  return failure();
595 
596  llvm::SmallVector<int64_t> outputShape;
597  outputShape.resize(3, ShapedType::kDynamic);
598  outputShape[0] = inputShape.getDimSize(0);
599  outputShape[1] = inputShape.getDimSize(1);
600  int64_t inWidth = inputShape.getDimSize(2);
601 
602  // Note that we can support this calculation symbolically
603  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
604  if (inWidth != ShapedType::kDynamic)
605  outputShape[2] = inWidth / 2 + 1;
606 
607  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
608  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
609 
610  return success();
611 }
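// For example, an 8x16x32 input infers two 8x16x17 outputs (32 / 2 + 1 = 17).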
612 
613 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
614  MLIRContext *context, ::std::optional<Location> location,
615  FFT2dOp::Adaptor adaptor,
616  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
617  inferredReturnShapes.push_back(
618  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
619  inferredReturnShapes.push_back(
620  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
621  return success();
622 }
623 
624 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
625  MLIRContext *context, ::std::optional<Location> location,
626  ConcatOp::Adaptor adaptor,
627  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
628  // Infer all dimension sizes by reducing based on inputs.
629  const Properties &prop = adaptor.getProperties();
630  int32_t axis = prop.axis.getValue().getSExtValue();
631  llvm::SmallVector<int64_t> outputShape;
632  bool hasRankedInput = false;
633  for (auto operand : adaptor.getOperands()) {
634  ShapeAdaptor operandShape(operand.getType());
635  if (!operandShape.hasRank())
636  continue;
637 
638  // Copy the Operand's rank.
639  if (!hasRankedInput)
640  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
641 
642  // Copy shapes until the dim is non-dynamic.
643  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
644  if (i == axis || operandShape.isDynamicDim(i))
645  continue;
646  if (outputShape[i] == ShapedType::kDynamic)
647  outputShape[i] = operandShape.getDimSize(i);
648  if (outputShape[i] != operandShape.getDimSize(i))
649  return emitOptionalError(location,
650  "Cannot concat tensors with different sizes"
651  " on the non-axis dimension ",
652  i);
653  }
654 
655  hasRankedInput = true;
656  }
657  Type inputType =
658  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
659  if (!hasRankedInput) {
660  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
661  return success();
662  }
663 
664  // Determine the dimension size along the concatenation axis.
665  int64_t concatDimSize = 0;
666  for (auto operand : adaptor.getOperands()) {
667  ShapeAdaptor operandShape(operand.getType());
668 
669  // We need to know the length of the concatenation axis of all inputs to
670  // determine the dimension size of the output shape.
671  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
672  concatDimSize = ShapedType::kDynamic;
673  break;
674  }
675 
676  concatDimSize += operandShape.getDimSize(axis);
677  }
678 
679  outputShape[axis] = concatDimSize;
680 
681  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
682  return success();
683 }
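// For example, concatenating 2x3 and 2x5 tensors on axis 1 infers a 2x8
// result; if any input's axis dimension is dynamic, the result's axis
// dimension is dynamic as well.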
684 
685 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
686  MLIRContext *context, ::std::optional<Location> location,
687  ValueShapeRange operands, DictionaryAttr attributes,
688  OpaqueProperties properties, RegionRange regions,
689  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
690  auto elementType = IntegerType::get(context, /*width=*/1);
691 
692  llvm::SmallVector<int64_t> outShape;
693  if (resolveBroadcastShape(operands, outShape).failed()) {
694  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
695  return success();
696  }
697 
698  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
699  return success();
700 }
701 
702 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
703  if (l.size() != r.size() || l.size() != 1)
704  return false;
705  return succeeded(verifyCompatibleShape(l[0], r[0]));
706 }
707 
708 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
709  MLIRContext *context, ::std::optional<Location> location,
710  FullyConnectedOp::Adaptor adaptor,
711  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
712  ShapeAdaptor inputShape(adaptor.getInput().getType());
713  ShapeAdaptor weightShape(adaptor.getWeight().getType());
714  ShapeAdaptor biasShape(adaptor.getBias().getType());
715 
716  // All shapes are dynamic.
717  SmallVector<int64_t> outShape;
718  outShape.resize(2, ShapedType::kDynamic);
719 
720  if (inputShape.hasRank()) {
721  outShape[0] = inputShape.getDimSize(0);
722  }
723 
724  if (weightShape.hasRank()) {
725  outShape[1] = weightShape.getDimSize(0);
726  }
727 
728  if (biasShape.hasRank()) {
729  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
730  : outShape[1];
731  }
732 
733  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
734  return success();
735 }
736 
737 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
738 
739 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
740  MLIRContext *context, ::std::optional<Location> location,
741  MatMulOp::Adaptor adaptor,
742  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
743  ShapeAdaptor lhsShape(adaptor.getA().getType());
744  ShapeAdaptor rhsShape(adaptor.getB().getType());
745 
746  // All shapes are dynamic.
747  SmallVector<int64_t> outShape;
748  outShape.resize(3, ShapedType::kDynamic);
749 
750  if (lhsShape.hasRank()) {
751  outShape[0] = lhsShape.getDimSize(0);
752  outShape[1] = lhsShape.getDimSize(1);
753  }
754 
755  if (rhsShape.hasRank()) {
756  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
757  : outShape[0];
758  outShape[2] = rhsShape.getDimSize(2);
759  }
760 
761  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
762  return success();
763 }
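// For example, operands of shape 2x3x4 and 2x4x5 infer a 2x3x5 result.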
764 
765 LogicalResult tosa::PadOp::inferReturnTypeComponents(
766  MLIRContext *context, ::std::optional<Location> location,
767  PadOp::Adaptor adaptor,
768  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
769  ShapeAdaptor inputShape(adaptor.getInput1().getType());
770  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
771  SmallVector<int64_t> outputShape;
772 
773  // If both inputs have unknown shape, we cannot determine the shape of the
774  // output.
775  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
776  inferredReturnShapes.push_back(ShapedTypeComponents());
777  return success();
778  }
779 
780  // If the input rank is unknown, we can infer the output rank using the
781  // padding shape's first dim.
782  if (!inputShape.hasRank()) {
783  if (paddingShape.isDynamicDim(0)) {
784  inferredReturnShapes.push_back(ShapedTypeComponents());
785  return success();
786  }
787 
788  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
789  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
790  return success();
791  }
792 
793  DenseIntElementsAttr paddings;
794  // If the paddings value is not a constant, all dimensions must be dynamic.
795  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
796  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
797  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
798  return success();
799  }
800 
801  SmallVector<int64_t> paddingValues;
802  for (auto val : paddings) {
803  paddingValues.push_back(val.getSExtValue());
804  }
805 
806  outputShape.reserve(inputShape.getRank());
807  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
808  if (inputShape.isDynamicDim(i)) {
809  outputShape.push_back(ShapedType::kDynamic);
810  continue;
811  }
812 
813  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
814  paddingValues[i * 2 + 1]);
815  }
816 
817  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
818  return success();
819 }
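// For example, a 2x3 input with constant padding [[1, 1], [2, 0]] infers a
// 4x5 result (each dimension grows by the sum of its two padding values).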
820 
821 LogicalResult tosa::PadOp::verify() {
822  RankedTensorType inputType = getInput1().getType();
823  RankedTensorType outputType = getOutput().getType();
824  TensorType paddingType = getPadding().getType();
825 
826  if (inputType.getRank() != outputType.getRank())
827  return emitOpError() << "expect same input and output tensor rank.";
828 
829  if (paddingType.hasRank() && paddingType.getRank() != 2)
830  return emitOpError() << "expect 'padding' tensor rank equal to 2.";
831 
832  return success();
833 }
834 
835 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
836  return to_vector(llvm::map_range(shape, [](int64_t dim) {
837  return dim == -1 ? ShapedType::kDynamic : dim;
838  }));
839 }
840 
841 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
842  MLIRContext *context, ::std::optional<Location> location,
843  SliceOp::Adaptor adaptor,
844  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
845  inferredReturnShapes.push_back(
846  ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
847  return success();
848 }
849 
850 LogicalResult tosa::SliceOp::verify() {
851  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
852  if (!inputType)
853  return success();
854 
855  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
856  return emitOpError(
857  "length of start attribute is not equal rank of input shape");
858 
859  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
860  return emitOpError(
861  "length of size attribute is not equal rank of input shape");
862 
863  return success();
864 }
865 
866 LogicalResult tosa::TableOp::inferReturnTypeComponents(
867  MLIRContext *context, ::std::optional<Location> location,
868  TableOp::Adaptor adaptor,
869  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
870  ShapeAdaptor inputShape(adaptor.getInput().getType());
871 
872  if (!inputShape.hasRank()) {
873  inferredReturnShapes.push_back(ShapedTypeComponents());
874  return success();
875  }
876 
877  inferredReturnShapes.resize(1);
878  inputShape.getDims(inferredReturnShapes[0]);
879  return success();
880 }
881 
882 LogicalResult tosa::TableOp::verify() {
883  TensorType inputType = getInput().getType();
884  TensorType outputType = getOutput().getType();
885 
886  if (inputType.hasRank() && outputType.hasRank() &&
887  inputType.getRank() != outputType.getRank())
888  return emitOpError()
889  << "expected input tensor rank to equal result tensor rank";
890 
891  auto inputDims = inputType.getShape();
892  auto outputDims = outputType.getShape();
893  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
894  int64_t dim = it.index();
895  auto [inputDim, outputDim] = it.value();
896  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
897  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
898  << " doesn't match dim(input, " << dim
899  << ") = " << inputDim;
900  }
901  }
902  return success();
903 }
904 
905 LogicalResult tosa::TileOp::inferReturnTypeComponents(
906  MLIRContext *context, ::std::optional<Location> location,
907  TileOp::Adaptor adaptor,
908  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
909  ArrayRef<int64_t> multiples = adaptor.getMultiples();
910  ShapeAdaptor inputShape(adaptor.getInput1().getType());
911  SmallVector<int64_t> outputShape;
912  if (!inputShape.hasRank()) {
913  outputShape.resize(multiples.size(), ShapedType::kDynamic);
914  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
915  return success();
916  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
917  return failure();
918 
919  // Any non dynamic dimension can be multiplied to a known size.
920  outputShape.reserve(multiples.size());
921  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
922  int64_t dim = inputShape.getDimSize(i);
923  if (dim != ShapedType::kDynamic)
924  dim *= multiples[i];
925  outputShape.push_back(dim);
926  }
927 
928  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
929  return success();
930 }
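// For example, a 2x3 input with multiples [3, 1] infers a 6x3 result.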
931 
932 LogicalResult tosa::TileOp::verify() {
933  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
934  ShapedType outputType = llvm::cast<ShapedType>(getType());
935  auto multiples = getMultiples();
936 
937  if (inputType.hasRank()) {
938  if (static_cast<size_t>(inputType.getRank()) != multiples.size())
939  return emitOpError("expect 'multiples' array to have length ")
940  << inputType.getRank() << " but got " << multiples.size() << ".";
941  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
942  return emitOpError("expect same input and output tensor rank.");
943  } else if (outputType.hasRank() &&
944  static_cast<size_t>(outputType.getRank()) != multiples.size())
945  return emitOpError("expect 'multiples' array to have length ")
946  << outputType.getRank() << " but got " << multiples.size() << ".";
947 
948  if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
949  return emitOpError(
950  "expect element of 'multiples' to be positive integer or -1.");
951 
952  return success();
953 }
954 
955 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
956  if (l.size() != r.size() || l.size() != 1)
957  return false;
958  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
959 }
960 
961 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
962  MLIRContext *context, ::std::optional<Location> location,
963  ReshapeOp::Adaptor adaptor,
964  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
965  ShapeAdaptor inputShape(adaptor.getInput1().getType());
966  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
967  llvm::SmallVector<int64_t> newShapeValue =
968  convertToMlirShape(adaptor.getNewShape());
969 
970  // We cannot infer from the total number of elements so we must take the
971  // shape attribute as exact.
972  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
973  inferredReturnShapes.push_back(
974  ShapedTypeComponents(newShapeValue, inputType));
975  return success();
976  }
977 
978  // Determine the number of elements covered by the slice of all static
979  // dimensions. This allows us to infer the length of the remaining dynamic
980  // dimension.
981  int64_t numElements = inputShape.getNumElements();
982  int64_t staticMul = 1;
983  for (auto val : newShapeValue) {
984  if (!ShapedType::isDynamic(val)) {
985  staticMul *= val;
986  }
987  }
988 
989  // Determine the length of the dynamic dimension.
990  for (auto &val : newShapeValue) {
991  if (ShapedType::isDynamic(val))
992  val = numElements / staticMul;
993  }
994 
995  inferredReturnShapes.push_back(
996  ShapedTypeComponents(newShapeValue, inputType));
997  return success();
998 }
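// For example, a 6x4 input with new_shape [-1, 8] infers a 3x8 result, since
// the single dynamic dimension is recovered as 24 / 8 = 3.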
999 
1000 llvm::LogicalResult tosa::ReshapeOp::verify() {
1001  TensorType inputType = getInput1().getType();
1002  RankedTensorType outputType = getType();
1003 
1004  if ((int64_t)getNewShape().size() != outputType.getRank())
1005  return emitOpError() << "new shape does not match result rank";
1006 
1007  for (auto [newShapeDim, outputShapeDim] :
1008  zip(getNewShape(), outputType.getShape())) {
1009  if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1010  newShapeDim != outputShapeDim)
1011  return emitOpError() << "new shape is inconsistent with result shape";
1012 
1013  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1014  return emitOpError() << "new shape has invalid tensor dimension size "
1015  << newShapeDim;
1016  }
1017 
1018  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
1019  int64_t inputElementsNum = inputType.getNumElements();
1020  int64_t outputElementsNum = outputType.getNumElements();
1021  if (inputElementsNum != outputElementsNum) {
1022  return emitOpError() << "cannot reshape " << inputElementsNum
1023  << " elements into " << outputElementsNum;
1024  }
1025  }
1026 
1027  int missingDims = llvm::count(getNewShape(), -1);
1028  if (missingDims > 1)
1029  return emitOpError() << "expected at most one target dimension to be -1";
1030 
1031  return mlir::success();
1032 }
1033 
1034 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1035  // Perms must be constants.
1036  DenseIntElementsAttr permsAttr;
1037  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1038  return failure();
1039 
1040  perms.clear();
1041  for (auto v : permsAttr.getValues<APInt>())
1042  perms.push_back(v.getSExtValue());
1043 
1044  return success();
1045 }
1046 
1047 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1048  MLIRContext *context, ::std::optional<Location> location,
1049  TransposeOp::Adaptor adaptor,
1050  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1051  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1052  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1053 
1054  // We cannot infer anything from a rank-0 "permutation" tensor.
1055  if (permsShape.hasRank() && permsShape.getRank() == 0)
1056  return failure();
1057 
1058  // If input rank and permutation length is unknown, the output rank is
1059  // unknown.
1060  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1061  permsShape.isDynamicDim(0)) {
1062  inferredReturnShapes.push_back(ShapedTypeComponents());
1063  return success();
1064  }
1065 
1066  // This would imply the number of permutations does not match the rank of
1067  // the input which is illegal.
1068  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1069  return failure();
1070  }
1071 
1072  SmallVector<int64_t> outputShape;
1073  // Rank-0 means no permutations matter.
1074  if (inputShape.getRank() == 0) {
1075  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1076  return success();
1077  }
1078 
1079  // Check whether the input dimensions are all the same.
1080  bool allTheSame = true;
1081  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1082  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1083  allTheSame = false;
1084  break;
1085  }
1086  }
1087 
1088  // If all of the input dimensions are the same we don't care about the
1089  // permutation.
1090  if (allTheSame) {
1091  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1092  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1093  return success();
1094  }
1095 
1096  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1097  // If the permutations are a constant we can directly determine the output
1098  // shape.
1099  DenseIntElementsAttr attr;
1100  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1101  attr.getType().getRank() == 1) {
1102  ShapeAdaptor permShape = attr;
1103  // Constant permutation must be the same length as the input rank.
1104  if (inputShape.getRank() != permShape.getRank())
1105  return emitOptionalError(location,
1106  "constant permutation must be the same length"
1107  " as the input rank");
1108 
1109  // Constant permutation values must be within the input rank.
1110  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1111  if (inputShape.getRank() <= permShape.getDimSize(i))
1112  return failure();
1113  }
1114 
1115  outputShape.reserve(inputShape.getRank());
1116  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1117  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1118  }
1119  }
1120 
1121  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1122  return success();
1123 }
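// For example, a 1x2x3 input with constant perms [2, 0, 1] infers a 3x1x2
// result; without constant perms, only the rank (all-dynamic dims) is known
// unless every input dimension is identical.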
1124 
1125 LogicalResult tosa::TransposeOp::verify() {
1126  TensorType inputType = getInput1().getType();
1127  TensorType permType = getPerms().getType();
1128  TensorType outputType = getOutput().getType();
1129 
1130  if (permType.hasRank() && permType.getRank() != 1)
1131  return emitOpError()
1132  << "expected permutation tensor to be rank 1 but got rank "
1133  << permType.getRank();
1134  if (inputType.hasRank() && permType.hasRank())
1135  if (!permType.isDynamicDim(0) &&
1136  permType.getDimSize(0) != inputType.getRank())
1137  return emitOpError() << "expected permutation tensor dim 0 to have size "
1138  << inputType.getRank()
1139  << " (input rank) but got size "
1140  << permType.getDimSize(0);
1141  if (inputType.hasRank() && outputType.hasRank() &&
1142  inputType.getRank() != outputType.getRank())
1143  return emitOpError()
1144  << "expected input tensor rank to equal result tensor rank";
1145  if (outputType.hasRank() && permType.hasRank())
1146  if (!permType.isDynamicDim(0) &&
1147  permType.getDimSize(0) != outputType.getRank())
1148  return emitOpError() << "expected permutation tensor dim 0 to have size "
1149  << outputType.getRank()
1150  << " (output rank) but got size "
1151  << permType.getDimSize(0);
1152 
1153  SmallVector<int32_t> constantPerms;
1154  if (succeeded(getConstantPerms(constantPerms))) {
1155  // Assert that the permutation tensor has a rank, which means that the
1156  // rank has been verified above.
1157  assert(permType.hasRank() &&
1158  "Unexpectedly found permutation tensor without rank");
1159  if (!llvm::all_of(constantPerms,
1160  [&constantPerms](int32_t s) {
1161  return s >= 0 &&
1162  static_cast<size_t>(s) < constantPerms.size();
1163  }) ||
1164  !isPermutationVector(llvm::to_vector(llvm::map_range(
1165  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1166  return emitOpError() << "expected valid permutation tensor";
1167 
1168  // Verify that the types of the input and output tensors are properly
1169  // permuted.
1170  if (inputType.hasRank() && outputType.hasRank()) {
1171  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1172  inputType.getRank() == outputType.getRank());
1173 
1174  for (auto i = 0; i < outputType.getRank(); i++) {
1175  if (inputType.isDynamicDim(constantPerms[i]) ||
1176  outputType.isDynamicDim(i))
1177  continue;
1178 
1179  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1180  return emitOpError()
1181  << "expected output tensor dim " << i << " to match "
1182  << "input dim " << constantPerms[i] << " with value of "
1183  << inputType.getDimSize(constantPerms[i]);
1184  }
1185  }
1186  }
1187  return success();
1188 }
1189 
1190 LogicalResult TransposeOp::reifyResultShapes(
1191  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1192 
1193  SmallVector<int32_t> transposePerms;
1194  if (getConstantPerms(transposePerms).failed())
1195  return failure();
1196 
1197  Value input = getInput1();
1198  auto inputType = cast<TensorType>(input.getType());
1199 
1200  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1201  for (auto dim : transposePerms) {
1202  int32_t dimInInput = transposePerms[dim];
1203  if (inputType.isDynamicDim(dimInInput))
1204  returnedDims[dim] =
1205  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1206  .getResult();
1207  else
1208  returnedDims[dim] =
1209  builder.getIndexAttr(inputType.getDimSize(dimInInput));
1210  }
1211 
1212  reifiedReturnShapes.emplace_back(std::move(returnedDims));
1213  return success();
1214 }
1215 
1216 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1217  MLIRContext *context, ::std::optional<Location> location,
1218  GatherOp::Adaptor adaptor,
1219  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1220  llvm::SmallVector<int64_t> outputShape;
1221  outputShape.resize(3, ShapedType::kDynamic);
1222 
1223  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1224  if (valuesShape.hasRank()) {
1225  outputShape[0] = valuesShape.getDimSize(0);
1226  outputShape[2] = valuesShape.getDimSize(2);
1227  }
1228 
1229  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1230  if (indicesShape.hasRank()) {
1231  if (outputShape[0] == ShapedType::kDynamic)
1232  outputShape[0] = indicesShape.getDimSize(0);
1233  if (outputShape[1] == ShapedType::kDynamic)
1234  outputShape[1] = indicesShape.getDimSize(1);
1235  }
1236 
1237  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1238  return success();
1239 }
1240 
1241 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1242  MLIRContext *context, ::std::optional<Location> location,
1243  ResizeOp::Adaptor adaptor,
1244  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1245  llvm::SmallVector<int64_t, 4> outputShape;
1246  outputShape.resize(4, ShapedType::kDynamic);
1247 
1248  ShapeAdaptor inputShape(adaptor.getInput().getType());
1249  if (!inputShape.hasRank())
1250  return failure();
1251 
1252  outputShape[0] = inputShape.getDimSize(0);
1253  outputShape[3] = inputShape.getDimSize(3);
1254  int64_t inputHeight = inputShape.getDimSize(1);
1255  int64_t inputWidth = inputShape.getDimSize(2);
1256 
1257  if ((inputHeight == ShapedType::kDynamic) ||
1258  (inputWidth == ShapedType::kDynamic))
1259  return failure();
1260 
1261  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1262  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1263  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1264 
1265  // Compute the output shape based on attributes: scale, offset, and border.
1266  outputShape[1] =
1267  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1268  scaleInt[1]) +
1269  1;
1270 
1271  outputShape[2] =
1272  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1273  scaleInt[3]) +
1274  1;
1275 
1276  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1277  return success();
1278 }
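// For example, a 1x8x8x3 input with scale [2, 1, 2, 1], offset [0, 0] and
// border [0, 0] infers a 1x15x15x3 result: ((8 - 1) * 2 - 0 + 0) / 1 + 1 = 15.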
1279 
1280 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1281  MLIRContext *context, ::std::optional<Location> location,
1282  ScatterOp::Adaptor adaptor,
1283  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1284  llvm::SmallVector<int64_t> outputShape;
1285  outputShape.resize(3, ShapedType::kDynamic);
1286 
1287  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1288  if (valuesInShape.hasRank()) {
1289  outputShape[0] = valuesInShape.getDimSize(0);
1290  outputShape[1] = valuesInShape.getDimSize(1);
1291  outputShape[2] = valuesInShape.getDimSize(2);
1292  }
1293 
1294  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1295  if (indicesShape.hasRank()) {
1296  if (outputShape[0] == ShapedType::kDynamic)
1297  outputShape[0] = indicesShape.getDimSize(0);
1298  }
1299 
1300  ShapeAdaptor inputShape(adaptor.getInput().getType());
1301  if (inputShape.hasRank()) {
1302  if (outputShape[0] == ShapedType::kDynamic)
1303  outputShape[0] = inputShape.getDimSize(0);
1304  if (outputShape[2] == ShapedType::kDynamic)
1305  outputShape[2] = inputShape.getDimSize(2);
1306  }
1307 
1308  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1309  return success();
1310 }
1311 
1312 static LogicalResult ReduceInferReturnTypes(
1313  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1314  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1315  int64_t axisVal = axis.getValue().getSExtValue();
1316  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1317  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1318  return success();
1319  }
1320 
1321  SmallVector<int64_t> outputShape;
1322  operandShape.getDims(outputShape);
1323  outputShape[axisVal] = 1;
1324  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1325  return success();
1326 }
1327 
1328 #define COMPATIBLE_RETURN_TYPES(OP) \
1329  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1330  if (l.size() != r.size() || l.size() != 1) \
1331  return false; \
1332  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1333  return false; \
1334  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1335  }
1336 
1337 #define REDUCE_SHAPE_INFER(OP) \
1338  LogicalResult OP::inferReturnTypeComponents( \
1339  MLIRContext *context, ::std::optional<Location> location, \
1340  OP::Adaptor adaptor, \
1341  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1342  Type inputType = \
1343  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1344  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1345  const Properties &prop = adaptor.getProperties(); \
1346  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1347  inferredReturnShapes); \
1348  } \
1349  COMPATIBLE_RETURN_TYPES(OP)
1350 
1351 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1352 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1353 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1354 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1355 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1356 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1357 #undef REDUCE_SHAPE_INFER
1358 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1359 #undef COMPATIBLE_RETURN_TYPES
1360 
1361 template <typename T>
1362 static LogicalResult verifyReduceOp(T op) {
1363  // All TOSA reduce Ops have input, output and axis.
1364  TensorType inputType = op.getInput().getType();
1365  TensorType outputType = op.getOutput().getType();
1366  int32_t reduceAxis = op.getAxis();
1367 
1368  if (reduceAxis < 0) {
1369  op.emitOpError("reduce axis must not be negative");
1370  return failure();
1371  }
1372  if (inputType.hasRank()) {
1373  int64_t inputRank = inputType.getRank();
1374  // We allow for a special case where the input/output shape has rank 0 and
1375  // axis is also 0.
1376  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1377  op.emitOpError("expect input tensor rank (")
1378  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1379  << ")";
1380  return failure();
1381  }
1382  }
1383  if (outputType.hasRank()) {
1384  int64_t outputRank = outputType.getRank();
1385  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1386  op.emitOpError(
1387  "expect output tensor rank to be equal to input tensor rank");
1388  return failure();
1389  }
1390  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1391  op.emitOpError("expect output tensor rank (")
1392  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1393  << ")";
1394  return failure();
1395  }
1396  // We can only verify the reduced dimension size to be 1 if this is not
1397  // the special case of output rank == 0.
1398  if (outputRank != 0) {
1399  auto outputShape = outputType.getShape();
1400  if (!outputType.isDynamicDim(reduceAxis) &&
1401  outputShape[reduceAxis] != 1) {
1402  op.emitOpError("expect reduced dimension size to be 1, got ")
1403  << outputShape[reduceAxis];
1404  return failure();
1405  }
1406  }
1407  }
1408  return success();
1409 }
1410 
1411 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1412 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1413 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1414 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1415 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1416 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1417 
1418 static LogicalResult NAryInferReturnTypes(
1419  const ValueShapeRange &operands,
1420  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1421  llvm::SmallVector<int64_t> outShape;
1422  if (resolveBroadcastShape(operands, outShape).failed()) {
1423  inferredReturnShapes.push_back(ShapedTypeComponents());
1424  } else {
1425  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1426  }
1427  return success();
1428 }
1429 
1430 #define NARY_SHAPE_INFER(OP) \
1431  LogicalResult OP::inferReturnTypeComponents( \
1432  MLIRContext *context, ::std::optional<Location> location, \
1433  ValueShapeRange operands, DictionaryAttr attributes, \
1434  OpaqueProperties properties, RegionRange regions, \
1435  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1436  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1437  }
1438 
1439 NARY_SHAPE_INFER(tosa::AbsOp)
1440 NARY_SHAPE_INFER(tosa::AddOp)
1441 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1442 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1443 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1444 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1445 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1446 NARY_SHAPE_INFER(tosa::CastOp)
1447 NARY_SHAPE_INFER(tosa::CeilOp)
1448 NARY_SHAPE_INFER(tosa::ClampOp)
1449 NARY_SHAPE_INFER(tosa::ClzOp)
1450 NARY_SHAPE_INFER(tosa::CosOp)
1451 NARY_SHAPE_INFER(tosa::ExpOp)
1452 NARY_SHAPE_INFER(tosa::FloorOp)
1453 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1454 NARY_SHAPE_INFER(tosa::GreaterOp)
1455 NARY_SHAPE_INFER(tosa::IdentityOp)
1456 NARY_SHAPE_INFER(tosa::IntDivOp)
1457 NARY_SHAPE_INFER(tosa::LogOp)
1458 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1459 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1460 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1461 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1462 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1463 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1464 NARY_SHAPE_INFER(tosa::MaximumOp)
1465 NARY_SHAPE_INFER(tosa::MinimumOp)
1466 NARY_SHAPE_INFER(tosa::MulOp)
1467 NARY_SHAPE_INFER(tosa::NegateOp)
1468 NARY_SHAPE_INFER(tosa::PowOp)
1469 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1470 NARY_SHAPE_INFER(tosa::RescaleOp)
1471 NARY_SHAPE_INFER(tosa::ReverseOp)
1472 NARY_SHAPE_INFER(tosa::RsqrtOp)
1473 NARY_SHAPE_INFER(tosa::SinOp)
1474 NARY_SHAPE_INFER(tosa::SelectOp)
1475 NARY_SHAPE_INFER(tosa::SubOp)
1476 NARY_SHAPE_INFER(tosa::TanhOp)
1477 NARY_SHAPE_INFER(tosa::ErfOp)
1478 NARY_SHAPE_INFER(tosa::SigmoidOp)
1479 #undef PRED_SHAPE_INFER
1480 
1481 static LogicalResult poolingInferReturnTypes(
1482  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1483  ArrayRef<int64_t> pad,
1484  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1485  llvm::SmallVector<int64_t> outputShape;
1486  outputShape.resize(4, ShapedType::kDynamic);
1487 
1488  // If the input shape is unknown, only the output rank is known.
1489  if (!inputShape) {
1490  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1491  return success();
1492  }
1493 
1494  // Batch and number of channels are identical for pooling layer.
1495  outputShape[0] = inputShape.getDimSize(0);
1496  outputShape[3] = inputShape.getDimSize(3);
1497 
1498  int64_t height = inputShape.getDimSize(1);
1499  int64_t width = inputShape.getDimSize(2);
1500 
1501  if (!ShapedType::isDynamic(height)) {
1502  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1503  outputShape[1] = padded / stride[0] + 1;
1504  }
1505 
1506  if (!ShapedType::isDynamic(width)) {
1507  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1508  outputShape[2] = padded / stride[1] + 1;
1509  }
1510 
1511  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1512  return success();
1513 }
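// Worked example (illustrative values): an input height of 32 with
// pad = {1, 1}, kernel = 3 and stride = 2 yields
// (32 + 1 + 1 - 3) / 2 + 1 = 16 for outputShape[1]; the width dimension is
// computed the same way from pad[2], pad[3], kernel[1] and stride[1].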
1514 
1515 LogicalResult Conv2DOp::inferReturnTypeComponents(
1516  MLIRContext *context, ::std::optional<Location> location,
1517  Conv2DOp::Adaptor adaptor,
1518  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1519  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1520 
1521  int64_t inputWidth = ShapedType::kDynamic;
1522  int64_t inputHeight = ShapedType::kDynamic;
1523  int64_t weightWidth = ShapedType::kDynamic;
1524  int64_t weightHeight = ShapedType::kDynamic;
1525 
1526  // Input shape describes input width/height and batch.
1527 
1528  ShapeAdaptor inputShape(adaptor.getInput().getType());
1529  if (inputShape.hasRank()) {
1530  outputShape[0] = inputShape.getDimSize(0);
1531  inputHeight = inputShape.getDimSize(1);
1532  inputWidth = inputShape.getDimSize(2);
1533  }
1534 
 1535  // The weight shape describes the filter height/width and the output channels.
1536  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1537  if (weightShape.hasRank()) {
1538  outputShape[3] = weightShape.getDimSize(0);
1539  weightHeight = weightShape.getDimSize(1);
1540  weightWidth = weightShape.getDimSize(2);
1541  }
1542 
1543  // Bias shape can describe the output channels.
1544  ShapeAdaptor biasShape(adaptor.getBias().getType());
1545  if (biasShape.hasRank()) {
1546  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1547  ? biasShape.getDimSize(0)
1548  : outputShape[3];
1549  }
1550 
1551  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1552  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1553  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1554 
1555  if (!ShapedType::isDynamic(inputHeight) &&
1556  !ShapedType::isDynamic(weightHeight)) {
1557  int64_t inputSize = inputHeight + padding[0] + padding[1];
1558  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1559  int64_t unstridedResult = inputSize - filterSize + 1;
1560  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1561  }
1562 
1563  if (!ShapedType::isDynamic(inputWidth) &&
1564  !ShapedType::isDynamic(weightWidth)) {
1565  int64_t inputSize = inputWidth + padding[2] + padding[3];
1566  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1567  int64_t unstridedResult = inputSize - filterSize + 1;
1568  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1569  }
1570 
1571  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1572  return success();
1573 }
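// Worked example (illustrative values): inputHeight = 56, pad = {1, 1},
// weightHeight = 3, dilation = 2 and stride = 2 give an effective filter size
// of (3 - 1) * 2 + 1 = 5, an unstrided result of 56 + 2 - 5 + 1 = 54, and an
// output height of (54 - 1) / 2 + 1 = 27. Any dynamic input or weight
// dimension leaves the corresponding output dimension dynamic.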
1574 
1575 LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1576 
1577 LogicalResult Conv3DOp::inferReturnTypeComponents(
1578  MLIRContext *context, ::std::optional<Location> location,
1579  Conv3DOp::Adaptor adaptor,
1580  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1581  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1582 
1583  int64_t inputWidth = ShapedType::kDynamic;
1584  int64_t inputHeight = ShapedType::kDynamic;
1585  int64_t inputDepth = ShapedType::kDynamic;
1586 
1587  int64_t weightWidth = ShapedType::kDynamic;
1588  int64_t weightHeight = ShapedType::kDynamic;
1589  int64_t weightDepth = ShapedType::kDynamic;
1590 
 1591  // The input shape describes the batch and the input depth/height/width.
1592  ShapeAdaptor inputShape(adaptor.getInput().getType());
1593  if (inputShape.hasRank()) {
1594  outputShape[0] = inputShape.getDimSize(0);
1595  inputDepth = inputShape.getDimSize(1);
1596  inputHeight = inputShape.getDimSize(2);
1597  inputWidth = inputShape.getDimSize(3);
1598  }
1599 
 1600  // The weight shape describes the filter depth/height/width and the output channels.
1601  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1602  if (weightShape.hasRank()) {
1603  outputShape[4] = weightShape.getDimSize(0);
1604  weightDepth = weightShape.getDimSize(1);
1605  weightHeight = weightShape.getDimSize(2);
1606  weightWidth = weightShape.getDimSize(3);
1607  }
1608 
1609  // Bias shape can describe the output channels.
1610  ShapeAdaptor biasShape(adaptor.getBias().getType());
1611  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1612  outputShape[4] = biasShape.getDimSize(0);
1613  }
1614 
1615  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1616  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1617  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1618 
1619  if (!ShapedType::isDynamic(inputDepth) &&
1620  !ShapedType::isDynamic(weightDepth)) {
 1621  int64_t inputSize = inputDepth + pad[0] + pad[1];
 1622  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
 1623  int64_t unstridedResult = inputSize - filterSize + 1;
1624  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1625  }
1626 
1627  if (!ShapedType::isDynamic(inputHeight) &&
1628  !ShapedType::isDynamic(weightHeight)) {
 1629  int64_t inputSize = inputHeight + pad[2] + pad[3];
 1630  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
 1631  int64_t unstridedResult = inputSize - filterSize + 1;
1632  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1633  }
1634 
1635  if (!ShapedType::isDynamic(inputWidth) &&
1636  !ShapedType::isDynamic(weightWidth)) {
 1637  int64_t inputSize = inputWidth + pad[4] + pad[5];
 1638  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
 1639  int64_t unstridedResult = inputSize - filterSize + 1;
1640  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1641  }
1642 
1643  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1644  return success();
1645 }
1646 
1647 LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1648 
1649 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1650  MLIRContext *context, ::std::optional<Location> location,
1651  AvgPool2dOp::Adaptor adaptor,
1652  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1653  ShapeAdaptor inputShape(adaptor.getInput().getType());
1654  const Properties &prop = adaptor.getProperties();
1655  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1656  inferredReturnShapes);
1657 }
1658 
1659 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1660  MLIRContext *context, ::std::optional<Location> location,
1661  MaxPool2dOp::Adaptor adaptor,
1662  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1663  ShapeAdaptor inputShape(adaptor.getInput().getType());
1664  const Properties &prop = adaptor.getProperties();
1665  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1666  inferredReturnShapes);
1667 }
1668 
1669 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1670  MLIRContext *context, ::std::optional<Location> location,
1671  DepthwiseConv2DOp::Adaptor adaptor,
1672  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1673  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1674 
1675  int64_t inputWidth = ShapedType::kDynamic;
1676  int64_t inputHeight = ShapedType::kDynamic;
1677  int64_t inputChannels = ShapedType::kDynamic;
1678 
1679  int64_t weightWidth = ShapedType::kDynamic;
1680  int64_t weightHeight = ShapedType::kDynamic;
1681  int64_t depthChannels = ShapedType::kDynamic;
1682 
 1683  // The input shape describes the batch, input height/width, and input channels.
1684  ShapeAdaptor inputShape(adaptor.getInput().getType());
1685  if (inputShape.hasRank()) {
1686  outputShape[0] = inputShape.getDimSize(0);
1687  inputHeight = inputShape.getDimSize(1);
1688  inputWidth = inputShape.getDimSize(2);
1689  inputChannels = inputShape.getDimSize(3);
1690  }
1691 
 1692  // The weight shape describes filter height/width, input channels and multiplier.
1693  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1694  if (weightShape.hasRank()) {
1695  weightHeight = weightShape.getDimSize(0);
1696  weightWidth = weightShape.getDimSize(1);
1697  inputChannels = ShapedType::isDynamic(inputChannels)
1698  ? weightShape.getDimSize(2)
1699  : inputChannels;
1700  depthChannels = weightShape.getDimSize(3);
1701  }
1702 
1703  // If both inputChannels and depthChannels are available we can determine
1704  // the output channels.
1705  if (!ShapedType::isDynamic(inputChannels) &&
1706  !ShapedType::isDynamic(depthChannels)) {
1707  outputShape[3] = inputChannels * depthChannels;
1708  }
1709 
1710  // Bias shape can describe the output channels.
1711  ShapeAdaptor biasShape(adaptor.getBias().getType());
1712  if (biasShape.hasRank()) {
1713  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1714  ? biasShape.getDimSize(0)
1715  : outputShape[3];
1716  }
1717 
1718  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1719  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1720  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1721 
1722  if (!ShapedType::isDynamic(inputHeight) &&
1723  !ShapedType::isDynamic(weightHeight)) {
1724  int64_t inputSize = inputHeight + padding[0] + padding[1];
1725  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1726  int64_t unstridedResult = inputSize - filterSize + 1;
1727  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1728  }
1729 
1730  if (!ShapedType::isDynamic(inputWidth) &&
1731  !ShapedType::isDynamic(weightWidth)) {
1732  int64_t inputSize = inputWidth + padding[2] + padding[3];
1733  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1734  int64_t unstridedResult = inputSize - filterSize + 1;
1735  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1736  }
1737 
1738  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1739  return success();
1740 }
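// Note: the depthwise weight is laid out as [kernelHeight, kernelWidth,
// inputChannels, channelMultiplier], so the output channel count computed
// above is inputChannels * depthChannels; e.g. 8 input channels with a
// channel multiplier of 2 give 16 output channels (illustrative values). The
// spatial dimensions follow the same padded, dilated, strided formula as
// Conv2DOp.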
1741 
1742 LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1743 
1744 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1745  MLIRContext *context, ::std::optional<Location> location,
1746  TransposeConv2DOp::Adaptor adaptor,
1747  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 1748  // Seed outputShape from out_shape; dynamic entries are refined below.
1749  llvm::SmallVector<int64_t> outputShape =
1750  convertToMlirShape(adaptor.getOutShape());
1751 
1752  int64_t inputWidth = ShapedType::kDynamic;
1753  int64_t inputHeight = ShapedType::kDynamic;
1754  int64_t weightWidth = ShapedType::kDynamic;
1755  int64_t weightHeight = ShapedType::kDynamic;
1756 
1757  // Input shape describes input width/height and batch.
1758  ShapeAdaptor inputShape(adaptor.getInput().getType());
1759  if (inputShape.hasRank()) {
1760  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1761  ? inputShape.getDimSize(0)
1762  : outputShape[0];
1763  inputHeight = inputShape.getDimSize(1);
1764  inputWidth = inputShape.getDimSize(2);
1765  }
1766 
 1767  // The filter shape describes the output channels and filter height/width.
1768  ShapeAdaptor weightShape(adaptor.getFilter().getType());
1769  if (weightShape.hasRank()) {
1770  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1771  ? weightShape.getDimSize(0)
1772  : outputShape[3];
1773  weightHeight = weightShape.getDimSize(1);
1774  weightWidth = weightShape.getDimSize(2);
1775  }
1776 
1777  // Bias shape can describe the output channels.
 1778  ShapeAdaptor biasShape(adaptor.getBias().getType());
1779  if (biasShape.hasRank()) {
1780  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1781  ? biasShape.getDimSize(0)
1782  : outputShape[3];
1783  }
1784 
1785  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
1786  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1787 
1788  if (!ShapedType::isDynamic(inputHeight) &&
1789  !ShapedType::isDynamic(weightHeight)) {
1790  int64_t calculateSize =
1791  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
1792  outputShape[1] =
1793  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
1794  }
1795 
1796  if (!ShapedType::isDynamic(inputWidth) &&
1797  !ShapedType::isDynamic(weightWidth)) {
1798  int64_t calculateSize =
1799  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
1800  outputShape[2] =
1801  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
1802  }
1803 
1804  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1805  return success();
1806 }
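// Worked example (illustrative values): inputHeight = 14, stride = 2,
// out_pad = {0, 1} and weightHeight = 3 give (14 - 1) * 2 + 0 + 1 + 3 = 30
// for the output height; the computed value is only used when the
// corresponding out_shape entry is dynamic.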
1807 
1808 LogicalResult IfOp::inferReturnTypeComponents(
1809  MLIRContext *context, ::std::optional<Location> location,
1810  IfOp::Adaptor adaptor,
1811  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 1812  llvm::SmallVector<tosa::YieldOp> yieldOps;
 1813  for (Region *region : adaptor.getRegions()) {
1814  for (auto &block : *region)
1815  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1816  yieldOps.push_back(returnOp);
1817  }
1818 
1819  if (yieldOps.empty())
1820  return failure();
1821 
 1822  // Get the initial type information from the first yield op's operands.
1823  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1824  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1825  for (auto operand : yieldOps.front().getOperands()) {
1826  resultKnowledge.push_back(
1827  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1828  }
1829 
1830  for (auto yieldOp : yieldOps) {
1831  if (resultKnowledge.size() != yieldOp.getNumOperands())
1832  return failure();
1833 
1834  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1835  int32_t index = it.index();
1836  auto meet = ValueKnowledge::meet(
1837  resultKnowledge[index],
1838  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1839  if (!meet)
1840  continue;
1841  resultKnowledge[index] = meet;
1842  }
1843  }
1844 
1845  for (const ValueKnowledge &result : resultKnowledge) {
1846  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1847  }
1848 
1849  return success();
1850 }
1851 
1852 LogicalResult WhileOp::inferReturnTypeComponents(
1853  MLIRContext *context, ::std::optional<Location> location,
1854  WhileOp::Adaptor adaptor,
1855  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
 1856  llvm::SmallVector<tosa::YieldOp> yieldOps;
 1857  for (auto &block : adaptor.getBody())
1858  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1859  yieldOps.push_back(returnOp);
1860 
 1861  // TOSA's while op must have a tosa.yield as its terminator. If none is
 1862  // found, this while op is invalid.
1863  if (yieldOps.empty())
1864  return failure();
1865 
1866  // Get the initial type information from the operand types.
1867  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1868  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1869  for (auto operand : yieldOps.front().getOperands()) {
1870  resultKnowledge.push_back(
1871  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1872  }
1873 
1874  for (auto yieldOp : yieldOps) {
1875  if (resultKnowledge.size() != yieldOp.getNumOperands())
1876  return failure();
1877 
1878  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1879  int32_t index = it.index();
1880  if (auto meet = ValueKnowledge::meet(
1881  resultKnowledge[index],
1882  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1883  resultKnowledge[index] = meet;
1884  }
1885  }
1886  }
1887 
1888  for (const ValueKnowledge &result : resultKnowledge) {
1889  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1890  }
1891 
1892  return success();
1893 }
1894 
1895 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
1896  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
1897  return llvm::to_vector<4>(vt.getShape());
1898  return std::nullopt;
1899 }
1900 
 1901 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
1902 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1903  // Create the regions for 'then'.
1904  result.regions.reserve(2);
1905  Region *thenRegion = result.addRegion();
1906  Region *elseRegion = result.addRegion();
1907 
1908  auto &builder = parser.getBuilder();
 1909  OpAsmParser::UnresolvedOperand cond;
 1910  // Create an i1 tensor type for the boolean condition.
1911  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1912  if (parser.parseOperand(cond) ||
1913  parser.resolveOperand(cond, i1Type, result.operands))
1914  return failure();
1915  // Parse optional results type list.
1916  if (parser.parseOptionalArrowTypeList(result.types))
1917  return failure();
1918  // Parse the 'then' region.
1919  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1920  return failure();
1921 
1922  // If we find an 'else' keyword then parse the 'else' region.
1923  if (!parser.parseOptionalKeyword("else")) {
1924  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1925  return failure();
1926  }
1927 
1928  // Parse the optional attribute list.
1929  if (parser.parseOptionalAttrDict(result.attributes))
1930  return failure();
1931  return success();
1932 }
1933 
1934 void IfOp::print(OpAsmPrinter &p) {
1935  bool printBlockTerminators = false;
1936 
1937  p << " " << getCond();
1938  if (!getResults().empty()) {
1939  p << " -> (" << getResultTypes() << ")";
1940  // Print yield explicitly if the op defines values.
1941  printBlockTerminators = true;
1942  }
1943  p << ' ';
1944  p.printRegion(getThenBranch(),
1945  /*printEntryBlockArgs=*/false,
1946  /*printBlockTerminators=*/printBlockTerminators);
1947 
 1948  // Print the 'else' region if it exists and has a block.
1949  auto &elseRegion = getElseBranch();
1950  if (!elseRegion.empty()) {
1951  p << " else ";
1952  p.printRegion(elseRegion,
1953  /*printEntryBlockArgs=*/false,
1954  /*printBlockTerminators=*/printBlockTerminators);
1955  }
1956 
1957  p.printOptionalAttrDict((*this)->getAttrs());
1958 }
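// For reference, the custom assembly handled by the parse/print pair above
// looks roughly like the following (operand names and types are illustrative):
//
//   tosa.cond_if %condition -> (tensor<f32>) {
//     ...
//     tosa.yield %then_value : tensor<f32>
//   } else {
//     ...
//     tosa.yield %else_value : tensor<f32>
//   }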
1959 
1960 LogicalResult ReverseOp::verify() {
1961  TensorType inputType = getInput().getType();
1962  TensorType outputType = getOutput().getType();
1963  int32_t reverseAxis = getAxis();
1964 
1965  if (reverseAxis < 0)
1966  return emitOpError("expected non-negative reverse axis");
1967  if (inputType.hasRank()) {
1968  int64_t inputRank = inputType.getRank();
1969  // We allow for a special case where the input/output shape has rank 0 and
1970  // axis is also 0.
1971  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1972  return emitOpError("expect input tensor rank (")
1973  << inputRank << ") to be larger than reverse axis (" << reverseAxis
1974  << ")";
1975  }
1976  if (outputType.hasRank()) {
1977  int64_t outputRank = outputType.getRank();
1978  if (inputType.hasRank() && outputRank != inputType.getRank())
1979  return emitOpError(
1980  "expect output tensor rank to be equal to input tensor rank");
1981  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
1982  return emitOpError("expect output tensor rank (")
1983  << outputRank << ") to be larger than reverse axis ("
1984  << reverseAxis << ")";
1985  }
1986  return success();
1987 }
1988 
 1989 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
1990 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
 1991  SmallVector<OpAsmParser::Argument, 4> regionArgs;
 1992  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
 1993  Region *cond = result.addRegion();
1994  Region *body = result.addRegion();
1995 
1996  OptionalParseResult listResult =
1997  parser.parseOptionalAssignmentList(regionArgs, operands);
1998  if (listResult.has_value() && failed(listResult.value()))
1999  return failure();
2000 
2001  FunctionType functionType;
2002  SMLoc typeLoc = parser.getCurrentLocation();
2003  if (failed(parser.parseColonType(functionType)))
2004  return failure();
2005 
2006  result.addTypes(functionType.getResults());
2007 
2008  if (functionType.getNumInputs() != operands.size()) {
2009  return parser.emitError(typeLoc)
2010  << "expected as many input types as operands "
2011  << "(expected " << operands.size() << " got "
2012  << functionType.getNumInputs() << ")";
2013  }
2014 
2015  // Resolve input operands.
2016  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2017  parser.getCurrentLocation(),
2018  result.operands)))
2019  return failure();
2020 
2021  // Propagate the types into the region arguments.
2022  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2023  regionArgs[i].type = functionType.getInput(i);
2024 
2025  return failure(parser.parseRegion(*cond, regionArgs) ||
2026  parser.parseKeyword("do") || parser.parseRegion(*body) ||
 2027  parser.parseOptionalAttrDictWithKeyword(result.attributes));
 2028 }
2029 
 2030 static void printInitializationList(OpAsmPrinter &parser,
 2031  Block::BlockArgListType blocksArgs,
2032  ValueRange initializers,
2033  StringRef prefix = "") {
2034  assert(blocksArgs.size() == initializers.size() &&
2035  "expected same length of arguments and initializers");
2036  if (initializers.empty())
2037  return;
2038 
2039  parser << prefix << '(';
2040  llvm::interleaveComma(
2041  llvm::zip(blocksArgs, initializers), parser,
2042  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
2043  parser << ")";
2044 }
2045 
2046 void WhileOp::print(OpAsmPrinter &parser) {
2047  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2048  " ");
2049  parser << " : ";
2050  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2051  parser << ' ';
2052  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2053  parser << " do ";
2054  parser.printRegion(getBody());
2055  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2056 }
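// For reference, the assembly handled by the parse/print pair above looks
// roughly like the following (names and types are illustrative):
//
//   tosa.while_loop (%arg0 = %initial) : (tensor<i32>) -> tensor<i32> {
//     %continue = ...
//     tosa.yield %continue : tensor<i1>
//   } do {
//   ^bb0(%arg0: tensor<i32>):
//     %next = ...
//     tosa.yield %next : tensor<i32>
//   }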
2057 
2058 //===----------------------------------------------------------------------===//
2059 // TOSA Attribute Definitions.
2060 //===----------------------------------------------------------------------===//
2061 
2062 #define GET_ATTRDEF_CLASSES
2063 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2064 
2065 //===----------------------------------------------------------------------===//
2066 // TOSA Operator Definitions.
2067 //===----------------------------------------------------------------------===//
2068 
2069 #define GET_OP_CLASSES
2070 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"