TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
39 
40 //===----------------------------------------------------------------------===//
41 // Tosa dialect interface includes.
42 //===----------------------------------------------------------------------===//
43 
44 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
45 
46 namespace {
47 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
48 
49 //===----------------------------------------------------------------------===//
50 // Dialect Function Inliner Interface.
51 //===----------------------------------------------------------------------===//
52 struct TosaInlinerInterface : public DialectInlinerInterface {
53  using DialectInlinerInterface::DialectInlinerInterface;
54 
55  //===--------------------------------------------------------------------===//
56  // Analysis Hooks.
57  //===--------------------------------------------------------------------===//
58 
59  /// All operations can be inlined by default.
60  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
61  IRMapping &map) const final {
62  return true;
63  }
64 
65  /// All regions with If and While parent operators can be inlined.
66  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
67  IRMapping &map) const final {
68  return (isa<tosa::IfOp>(dest->getParentOp()) ||
69  isa<tosa::WhileOp>(dest->getParentOp()));
70  }
71 };
72 
73 /// This class implements the bytecode interface for the Tosa dialect.
74 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
75  TosaDialectBytecodeInterface(Dialect *dialect)
76  : BytecodeDialectInterface(dialect) {}
77 
78  //===--------------------------------------------------------------------===//
79  // Attributes
80 
81  Attribute readAttribute(DialectBytecodeReader &reader) const override {
82  return ::readAttribute(getContext(), reader);
83  }
84 
85  LogicalResult writeAttribute(Attribute attr,
86  DialectBytecodeWriter &writer) const override {
87  return ::writeAttribute(attr, writer);
88  }
89 
90  //===--------------------------------------------------------------------===//
91  // Types
92 
93  Type readType(DialectBytecodeReader &reader) const override {
94  return ::readType(getContext(), reader);
95  }
96 
97  LogicalResult writeType(Type type,
98  DialectBytecodeWriter &writer) const override {
99  return ::writeType(type, writer);
100  }
101 
102  void writeVersion(DialectBytecodeWriter &writer) const final {
103  // TODO: Populate.
104  }
105 
106  std::unique_ptr<DialectVersion>
107  readVersion(DialectBytecodeReader &reader) const final {
108  // TODO: Populate
109  reader.emitError("Dialect does not support versioning");
110  return nullptr;
111  }
112 
113  LogicalResult upgradeFromVersion(Operation *topLevelOp,
114  const DialectVersion &version) const final {
115  return success();
116  }
117 };
118 
119 } // namespace
120 
121 //===----------------------------------------------------------------------===//
122 // TOSA control flow support.
123 //===----------------------------------------------------------------------===//
124 
125 /// Returns the while loop body.
126 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
127 
128 //===----------------------------------------------------------------------===//
129 // Tosa dialect initialization.
130 //===----------------------------------------------------------------------===//
131 
132 void TosaDialect::initialize() {
133  addOperations<
134 #define GET_OP_LIST
135 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
136  >();
137  addAttributes<
138 #define GET_ATTRDEF_LIST
139 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
140  >();
141  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
142  declarePromisedInterfaces<
143  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
144  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
145  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
146  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
147  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
148  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
149  GreaterEqualOp, MatMulOp>();
150 }
151 
152 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
153  Type type, Location loc) {
154  // Tosa dialect constants only support ElementsAttr, unlike the standard
155  // dialect constant which supports all attributes.
156  if (llvm::isa<ElementsAttr>(value))
157  return builder.create<tosa::ConstOp>(loc, type,
158  llvm::cast<ElementsAttr>(value));
159  return nullptr;
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // Parsers and printers
164 //===----------------------------------------------------------------------===//
165 
166 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
167  Attribute &attr) {
168  if (succeeded(parser.parseOptionalEqual())) {
169  if (failed(parser.parseAttribute(attr))) {
170  return parser.emitError(parser.getCurrentLocation())
171  << "expected attribute";
172  }
173  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
174  typeAttr = TypeAttr::get(typedAttr.getType());
175  }
176  return success();
177  }
178 
179  Type type;
180  if (failed(parser.parseColonType(type))) {
181  return parser.emitError(parser.getCurrentLocation()) << "expected type";
182  }
183  typeAttr = TypeAttr::get(type);
184 
185  return success();
186 }
187 
188 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
189  Attribute attr) {
190  bool needsSpace = false;
191  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
192  if (!typedAttr || typedAttr.getType() != type.getValue()) {
193  p << ": ";
194  p.printAttribute(type);
195  needsSpace = true; // subsequent attr value needs a space separator
196  }
197  if (attr) {
198  if (needsSpace)
199  p << ' ';
200  p << "= ";
201  p.printAttribute(attr);
202  }
203 }
204 
205 //===----------------------------------------------------------------------===//
206 // TOSA Operator Verifiers.
207 //===----------------------------------------------------------------------===//
208 
209 template <typename T>
210 static LogicalResult verifyConvOp(T op) {
211  // All TOSA conv ops have an input() and weight().
212  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
213  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
214 
215  // Must be ranked tensor types
216  if (!inputType) {
217  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
218  return failure();
219  }
220  if (!weightType) {
221  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
222  return failure();
223  }
224 
225  auto inputEType = inputType.getElementType();
226  auto weightEType = weightType.getElementType();
227 
228  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
229  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
230 
231  // Either both must be quantized or both unquantized.
232  if (inputIsQuant != weightIsQuant) {
233  op.emitOpError(
234  "expect both input and weight to be float or not together, got ")
235  << inputEType << " and " << weightEType;
236  return failure();
237  }
238 
239  // A quantized type must have a constructed quantization attribute, and an
240  // unquantized type must not have one.
241  if ((inputIsQuant && !op.getQuantizationInfo()) ||
242  (!inputIsQuant && op.getQuantizationInfo())) {
243  op.emitOpError("quantizationattr is required for quantized type, and not "
244  "allowed for float type");
245  return failure();
246  }
247  return success();
248 }
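// Illustrative consequence of the checks above (assumed element types): an
// f32 input paired with an i8 weight is rejected, an i8 input/weight pair
// additionally requires quantization_info, and an all-float convolution must
// not carry quantization_info.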
249 
250 LogicalResult tosa::ConstOp::verify() {
251 
252  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
253  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
254 
255  if (!attrType || !outputType) {
256  emitOpError("expected tensors for attr/result type");
257  return failure();
258  }
259 
260  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
261  outputType.getElementType())) {
262  if (result.getStorageType() == attrType.getElementType())
263  return success();
264  }
265 
266  if (attrType.getElementType() != outputType.getElementType()) {
267  emitOpError("expected same attr/result element types");
268  return failure();
269  }
270 
271  return success();
272 }
273 
274 LogicalResult tosa::ArgMaxOp::verify() {
275  // Ensure output is of 32-bit integer
276  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
277  if (!resultETy.isIntOrIndex())
278  return emitOpError("result tensor is not of integer type");
279 
280  // Ensure axis is within the tensor rank
281  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
282  const int64_t axis = getAxisAttr().getInt();
283  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
284  return emitOpError("specified axis is outside the rank of the tensor");
285 
286  return success();
287 }
288 
289 LogicalResult tosa::AvgPool2dOp::verify() {
290  auto inputType = llvm::cast<ShapedType>(getInput().getType());
291 
292  auto inputETy = inputType.getElementType();
293  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
294 
295  if (auto quantType =
296  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
297  inputETy = quantType.getStorageType();
298 
299  if (auto quantType =
300  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
301  resultETy = quantType.getStorageType();
302 
303  auto accType = getAccType();
304  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
305  return emitOpError("accumulator type for integer tensor is not i32");
306 
307  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
308  return emitOpError("accumulator type for f16 tensor is not f16/f32");
309 
310  if (inputETy.isBF16() && !accType.isF32())
311  return emitOpError("accumulator type for bf16 tensor is not f32");
312 
313  if (inputETy.isF32() && !accType.isF32())
314  return emitOpError("accumulator type for f32 tensor is not f32");
315 
316  if ((inputETy.isF32() && resultETy.isF32()) ||
317  (inputETy.isF16() && resultETy.isF16()) ||
318  (inputETy.isBF16() && resultETy.isBF16()) ||
319  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
320  (inputETy.isInteger(16) && resultETy.isInteger(16)))
321  return success();
322 
323  return emitOpError("input/output element types are incompatible.");
324 }
325 
326 LogicalResult tosa::ClampOp::verify() {
327  mlir::Type inputETy =
328  llvm::cast<ShapedType>(getInput().getType()).getElementType();
329  if (auto quantType =
330  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
331  inputETy = quantType.getStorageType();
332  }
333  mlir::Type maxFpType = getMaxFpAttr().getType();
334  mlir::Type minFpType = getMinFpAttr().getType();
335  mlir::Type outputETy =
336  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
337  if (auto quantType =
338  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
339  outputETy = quantType.getStorageType();
340  }
341  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
342 
343  if (inputETy != outputETy)
344  return emitOpError("input/output element types are incompatible.");
345 
346  // If input datatype is float, check that the two min/max_fp attributes
347  // share the same type and that their type is either the same as the input's
348  // datatype, or a float type whose bitwidth > input datatype bitwidth.
349  if (!inputETy.isInteger(dataTypeBitWidth)) {
350  if (((maxFpType != minFpType) ||
351  (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
352  inputETy.getIntOrFloatBitWidth())))
353  return emitOpError("min/max attributes types are incompatible with "
354  "input/output element types.");
355  }
356 
357  return success();
358 }
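// Illustrative example of the float min/max rule above (assumed types): an
// f16 input with f32 min_fp/max_fp attributes is accepted (same attribute
// type, wider float), while an f16 min_fp combined with an f32 max_fp is
// rejected because the two attributes differ in type.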
359 
360 //===----------------------------------------------------------------------===//
361 // TOSA Operator Quantization Builders.
362 //===----------------------------------------------------------------------===//
363 
364 /// This builder is called on all convolution operators except TransposeConv,
365 /// which has specialized output shape semantics. The builder also defines the
366 /// bit width of the output given the bit widths of the input and weight.
367 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
368  Type outputType, Value input, Value weight,
369  Value bias, DenseI64ArrayAttr pad,
370  DenseI64ArrayAttr stride,
371  DenseI64ArrayAttr dilation) {
372 
373  result.addOperands({input, weight, bias});
374  result.addAttribute("pad", pad);
375  result.addAttribute("stride", stride);
376  result.addAttribute("dilation", dilation);
377 
378  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
379  if (quantAttr) {
380  result.addAttribute("quantization_info", quantAttr);
381  result.addTypes(
382  buildConvOpResultTypeInfo(builder, outputType, input, weight));
383  } else {
384  result.addTypes(outputType);
385  }
386 }
387 
388 /// Handles tosa.transpose_conv2d which has outpad and output shape
389 /// attributes.
390 static void buildTransConvOpWithQuantInfo(
391  OpBuilder &builder, OperationState &result, Type outputType, Value input,
392  Value weight, Value bias, DenseI64ArrayAttr outpad,
393  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
394  result.addOperands({input, weight, bias});
395  result.addAttribute("out_pad", outpad);
396  result.addAttribute("stride", stride);
397  result.addAttribute("out_shape", outputShape);
398  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
399 
400  if (quantAttr) {
401  result.addAttribute("quantization_info", quantAttr);
402  result.addTypes(
403  buildConvOpResultTypeInfo(builder, outputType, input, weight));
404  } else {
405  result.addTypes(outputType);
406  }
407 }
408 
409 /// The tosa.fully_connected op has its own builder as it does not have
410 /// strides/dilation/padding.
411 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
412  Type outputType, Value input, Value weight,
413  Value bias) {
414 
415  result.addOperands({input, weight, bias});
416  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
417  if (quantAttr) {
418  result.addAttribute("quantization_info", quantAttr);
419  result.addTypes(
420  buildConvOpResultTypeInfo(builder, outputType, input, weight));
421  } else {
422  result.addTypes(outputType);
423  }
424 }
425 
426 /// The tosa.matmul op is also intended to be generated where a
427 /// fully_connected op would be constructed but the weight is not a constant.
428 /// In this case, the fully_connected op must be expressed using matmul.
429 /// TODO: Add link to the legalization document explaining this.
430 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
431  OperationState &result, Type outputType,
432  Value a, Value b) {
433  result.addOperands({a, b});
434  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
435 
436  if (quantAttr) {
437  result.addAttribute("quantization_info", quantAttr);
438 
439  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
440  assert(inputType && "Input must be a shaped tensor type!");
441 
442  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
443  inputType.getElementType());
444  assert(inputQType && "Tensor must have quantized datatype!");
445 
446  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
447 
448  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
449  assert(outputShapedType && "Output must be a shaped type");
450 
451  IntegerType accElementType;
452  if (inputBits == 16)
453  accElementType = builder.getIntegerType(48);
454  else
455  accElementType = builder.getI32Type();
456  auto accType = outputShapedType.clone(accElementType);
457  result.addTypes(accType);
458  } else {
459  result.addTypes(outputType);
460  }
461 }
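// Accumulator selection above, illustrated with assumed storage types: i8
// quantized inputs produce an i32 accumulator/result element type, while i16
// quantized inputs produce an i48 accumulator/result element type.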
462 
463 /// Both the tosa.avg_pool2d and unary ops use the same
464 /// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
465 /// as it has additional parameters not part of the unary ops.
466 static void
467 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
468  Type outputType, Value input,
469  DenseArrayAttr kernel, DenseArrayAttr stride,
470  DenseArrayAttr pad, TypeAttr accType) {
471  result.addOperands(input);
472  result.addAttribute("kernel", kernel);
473  result.addAttribute("stride", stride);
474  result.addAttribute("pad", pad);
475  result.addAttribute("acc_type", accType);
476  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
477  if (quantAttr)
478  result.addAttribute("quantization_info", quantAttr);
479  result.types.push_back(outputType);
480 }
481 
482 /// This builder is called on single-parameter unary operators that have a
483 /// scale relationship between their input and output, expressed by the
484 /// UnaryOpQuantizationAttr.
485 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
486  OperationState &result, Type outputType,
487  Value input) {
488  result.addOperands(input);
489  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
490  if (quantAttr)
491  result.addAttribute("quantization_info", quantAttr);
492  result.types.push_back(outputType);
493 }
494 
495 /// This builder is called on the TOSA pad operator that needs to create its
496 /// own OptionalAttr quantization_attr parameter to scale the padding values
497 /// correctly. The absence of pad_const is interpreted as zero padding.
498 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
499  Type outputType, Value input,
500  Value paddings) {
501  result.addOperands({input, paddings});
502  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
503  if (quantAttr)
504  result.addAttribute("quantization_info", quantAttr);
505  result.types.push_back(outputType);
506 }
507 
508 /// This builder is called on the TOSA pad operator when an explicit pad_const
509 /// value is passed in. It also optionally constructs quantization_attr.
510 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
511  OperationState &result,
512  Type outputType, Value input,
513  Value paddings,
514  Value padConst) {
515  result.addOperands({input, paddings, padConst});
516  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
517  if (quantAttr)
518  result.addAttribute("quantization_info", quantAttr);
519  result.types.push_back(outputType);
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // TOSA Operator Return Type Inference.
524 //===----------------------------------------------------------------------===//
525 
526 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
527  SmallVector<int64_t> &outShape) {
528  int64_t outRank = 0;
529  for (int i = 0, e = operands.size(); i != e; ++i) {
530  auto shape = operands.getShape(i);
531  if (!shape.hasRank()) {
532  // TODO(jennik): Update function to have better case handling for
533  // invalid operands and for ranked tensors.
534  return failure();
535  }
536  outRank = std::max<int64_t>(outRank, shape.getRank());
537  }
538 
539  outShape.resize(outRank, 1);
540 
541  for (int i = 0, e = operands.size(); i != e; ++i) {
542  auto shape = operands.getShape(i);
543  auto rankDiff = outShape.size() - shape.getRank();
544 
545  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
546  auto dim1 = outShape[i + rankDiff];
547  auto dim2 = shape.getDimSize(i);
548  auto resolvedDim = dim1;
549 
550  if (dim1 == 1) {
551  resolvedDim = dim2;
552  } else if (dim2 == 1) {
553  resolvedDim = dim1;
554  } else if (dim1 != dim2) {
555  return failure();
556  }
557  outShape[i + rankDiff] = resolvedDim;
558  }
559  }
560 
561  return success();
562 }
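// Illustrative broadcast resolution with assumed shapes: ranks are aligned on
// the right and size-1 dims stretch, e.g. [2, 1, 3] and [4, 3] resolve to
// [2, 4, 3], while [2, 3] and [4, 3] fail because 2 and 4 cannot broadcast.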
563 
564 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
565  MLIRContext *context, ::std::optional<Location> location,
566  ArgMaxOp::Adaptor adaptor,
567  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
568  ShapeAdaptor inputShape(adaptor.getInput().getType());
569  IntegerAttr axis = adaptor.getProperties().axis;
570  int32_t axisVal = axis.getValue().getSExtValue();
571 
572  if (!inputShape.hasRank()) {
573  inferredReturnShapes.push_back(ShapedTypeComponents());
574  return success();
575  }
576 
577  SmallVector<int64_t> outShape;
578  outShape.reserve(inputShape.getRank() - 1);
579  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
580  if (i == axisVal)
581  continue;
582  outShape.push_back(inputShape.getDimSize(i));
583  }
584 
585  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
586  return success();
587 }
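// Illustrative example with an assumed shape: an input of [2, 3, 4] with
// axis = 1 infers an output shape of [2, 4] (the reduced axis is dropped).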
588 
589 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
590  MLIRContext *context, ::std::optional<Location> location,
591  RFFT2dOp::Adaptor adaptor,
592  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
593  ShapeAdaptor inputShape(adaptor.getInput().getType());
594 
595  if (!inputShape.hasRank())
596  return failure();
597 
598  llvm::SmallVector<int64_t> outputShape;
599  outputShape.resize(3, ShapedType::kDynamic);
600  outputShape[0] = inputShape.getDimSize(0);
601  outputShape[1] = inputShape.getDimSize(1);
602  int64_t inWidth = inputShape.getDimSize(2);
603 
604  // Note that we can support this calculation symbolically
605  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
606  if (inWidth != ShapedType::kDynamic)
607  outputShape[2] = inWidth / 2 + 1;
608 
609  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
610  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
611 
612  return success();
613 }
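// Illustrative example with an assumed shape: an input of [8, 16, 32] infers
// [8, 16, 17] (32 / 2 + 1) for both the real and imaginary outputs.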
614 
615 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
616  MLIRContext *context, ::std::optional<Location> location,
617  FFT2dOp::Adaptor adaptor,
618  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
619  inferredReturnShapes.push_back(
620  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
621  inferredReturnShapes.push_back(
622  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
623  return success();
624 }
625 
626 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
627  MLIRContext *context, ::std::optional<Location> location,
628  ConcatOp::Adaptor adaptor,
629  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
630  // Infer all dimension sizes by reducing based on inputs.
631  const Properties &prop = adaptor.getProperties();
632  int32_t axis = prop.axis.getValue().getSExtValue();
633  llvm::SmallVector<int64_t> outputShape;
634  bool hasRankedInput = false;
635  for (auto operand : adaptor.getOperands()) {
636  ShapeAdaptor operandShape(operand.getType());
637  if (!operandShape.hasRank())
638  continue;
639 
640  // Copy the operand's rank.
641  if (!hasRankedInput)
642  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
643 
644  // Copy static dim sizes on non-axis dims and check they agree across operands.
645  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
646  if (i == axis || operandShape.isDynamicDim(i))
647  continue;
648  if (outputShape[i] == ShapedType::kDynamic)
649  outputShape[i] = operandShape.getDimSize(i);
650  if (outputShape[i] != operandShape.getDimSize(i))
651  return emitOptionalError(location,
652  "Cannot concat tensors with different sizes"
653  " on the non-axis dimension ",
654  i);
655  }
656 
657  hasRankedInput = true;
658  }
659  Type inputType =
660  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
661  if (!hasRankedInput) {
662  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
663  return success();
664  }
665 
666  // Determine the dimension size along the concatenation axis.
667  int64_t concatDimSize = 0;
668  for (auto operand : adaptor.getOperands()) {
669  ShapeAdaptor operandShape(operand.getType());
670 
671  // We need to know the length of the concatenation axis of all inputs to
672  // determine the dimension size of the output shape.
673  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
674  concatDimSize = ShapedType::kDynamic;
675  break;
676  }
677 
678  concatDimSize += operandShape.getDimSize(axis);
679  }
680 
681  outputShape[axis] = concatDimSize;
682 
683  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
684  return success();
685 }
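// Illustrative example with assumed shapes: concatenating [2, 3] and [2, 5]
// along axis = 1 infers [2, 8]; if any operand's axis dimension is dynamic,
// the inferred axis dimension is dynamic as well.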
686 
687 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
688  MLIRContext *context, ::std::optional<Location> location,
689  ValueShapeRange operands, DictionaryAttr attributes,
690  OpaqueProperties properties, RegionRange regions,
691  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
692  auto elementType = IntegerType::get(context, /*width=*/1);
693 
694  llvm::SmallVector<int64_t> outShape;
695  if (resolveBroadcastShape(operands, outShape).failed()) {
696  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
697  return success();
698  }
699 
700  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
701  return success();
702 }
703 
704 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
705  if (l.size() != r.size() || l.size() != 1)
706  return false;
707  return succeeded(verifyCompatibleShape(l[0], r[0]));
708 }
709 
710 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
711  MLIRContext *context, ::std::optional<Location> location,
712  FullyConnectedOp::Adaptor adaptor,
713  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
714  ShapeAdaptor inputShape(adaptor.getInput().getType());
715  ShapeAdaptor weightShape(adaptor.getWeight().getType());
716  ShapeAdaptor biasShape(adaptor.getBias().getType());
717 
718  // All shapes are dynamic.
719  SmallVector<int64_t> outShape;
720  outShape.resize(2, ShapedType::kDynamic);
721 
722  if (inputShape.hasRank()) {
723  outShape[0] = inputShape.getDimSize(0);
724  }
725 
726  if (weightShape.hasRank()) {
727  outShape[1] = weightShape.getDimSize(0);
728  }
729 
730  if (biasShape.hasRank()) {
731  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
732  : outShape[1];
733  }
734 
735  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
736  return success();
737 }
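// Illustrative example with assumed shapes: input [N, IC], weight [OC, IC]
// and bias [OC] infer an output shape of [N, OC].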
738 
739 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
740 
741 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
742  MLIRContext *context, ::std::optional<Location> location,
743  MatMulOp::Adaptor adaptor,
744  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
745  ShapeAdaptor lhsShape(adaptor.getA().getType());
746  ShapeAdaptor rhsShape(adaptor.getB().getType());
747 
748  // All shapes are dynamic.
749  SmallVector<int64_t> outShape;
750  outShape.resize(3, ShapedType::kDynamic);
751 
752  if (lhsShape.hasRank()) {
753  outShape[0] = lhsShape.getDimSize(0);
754  outShape[1] = lhsShape.getDimSize(1);
755  }
756 
757  if (rhsShape.hasRank()) {
758  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
759  : outShape[0];
760  outShape[2] = rhsShape.getDimSize(2);
761  }
762 
763  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
764  return success();
765 }
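// Illustrative example with assumed shapes: A of [N, H, C] and B of [N, C, W]
// infer an output shape of [N, H, W].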
766 
767 LogicalResult tosa::PadOp::inferReturnTypeComponents(
768  MLIRContext *context, ::std::optional<Location> location,
769  PadOp::Adaptor adaptor,
770  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
771  ShapeAdaptor inputShape(adaptor.getInput1().getType());
772  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
773  SmallVector<int64_t> outputShape;
774 
775  // If both inputs have unknown shape, we cannot determine the shape of the
776  // output.
777  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
778  inferredReturnShapes.push_back(ShapedTypeComponents());
779  return success();
780  }
781 
782  // If the input rank is unknown, we can infer the output rank using the
783  // padding shape's first dim.
784  if (!inputShape.hasRank()) {
785  if (paddingShape.isDynamicDim(0)) {
786  inferredReturnShapes.push_back(ShapedTypeComponents());
787  return success();
788  }
789 
790  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
791  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
792  return success();
793  }
794 
795  DenseIntElementsAttr paddings;
796  // If the paddings value is not a constant, all dimensions must be dynamic.
797  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
798  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
799  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
800  return success();
801  }
802 
803  SmallVector<int64_t> paddingValues;
804  for (auto val : paddings) {
805  paddingValues.push_back(val.getSExtValue());
806  }
807 
808  outputShape.reserve(inputShape.getRank());
809  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
810  if (inputShape.isDynamicDim(i)) {
811  outputShape.push_back(ShapedType::kDynamic);
812  continue;
813  }
814 
815  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
816  paddingValues[i * 2 + 1]);
817  }
818 
819  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
820  return success();
821 }
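// Illustrative example with assumed shapes: an input of [3, ?] with a
// constant padding of [[1, 1], [2, 2]] infers [5, ?] (3 + 1 + 1 for dim 0;
// the dynamic dim stays dynamic).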
822 
823 LogicalResult tosa::PadOp::verify() {
824  RankedTensorType inputType = getInput1().getType();
825  RankedTensorType outputType = getOutput().getType();
826  TensorType paddingType = getPadding().getType();
827 
828  if (inputType.getRank() != outputType.getRank())
829  return emitOpError() << "expect same input and output tensor rank.";
830 
831  if (paddingType.hasRank() && paddingType.getRank() != 2)
832  return emitOpError() << "expect 'padding' tensor rank equal to 2.";
833 
834  return success();
835 }
836 
837 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
838  return to_vector(llvm::map_range(shape, [](int64_t dim) {
839  return dim == -1 ? ShapedType::kDynamic : dim;
840  }));
841 }
842 
843 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
844  MLIRContext *context, ::std::optional<Location> location,
845  SliceOp::Adaptor adaptor,
846  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
847  inferredReturnShapes.push_back(
848  ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
849  return success();
850 }
851 
852 LogicalResult tosa::SliceOp::verify() {
853  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
854  if (!inputType)
855  return success();
856 
857  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
858  return emitOpError(
859  "length of start attribute is not equal rank of input shape");
860 
861  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
862  return emitOpError(
863  "length of size attribute is not equal rank of input shape");
864 
865  return success();
866 }
867 
868 LogicalResult tosa::MulOp::verify() {
869  Type elementTy = getInput1().getType().getElementType();
870  if (isa<FloatType>(elementTy) && getShift() != 0)
871  return emitOpError() << "require shift to be 0 for float type";
872 
873  return success();
874 }
875 
876 LogicalResult tosa::TableOp::inferReturnTypeComponents(
877  MLIRContext *context, ::std::optional<Location> location,
878  TableOp::Adaptor adaptor,
879  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
880  ShapeAdaptor inputShape(adaptor.getInput1().getType());
881 
882  if (!inputShape.hasRank()) {
883  inferredReturnShapes.push_back(ShapedTypeComponents());
884  return success();
885  }
886 
887  inferredReturnShapes.resize(1);
888  inputShape.getDims(inferredReturnShapes[0]);
889  return success();
890 }
891 
892 LogicalResult tosa::TableOp::verify() {
893  TensorType inputType = getInput1().getType();
894  TensorType outputType = getOutput().getType();
895 
896  if (inputType.hasRank() && outputType.hasRank() &&
897  inputType.getRank() != outputType.getRank())
898  return emitOpError()
899  << "expected input tensor rank to equal result tensor rank";
900 
901  auto inputDims = inputType.getShape();
902  auto outputDims = outputType.getShape();
903  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
904  int64_t dim = it.index();
905  auto [inputDim, outputDim] = it.value();
906  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
907  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
908  << " doesn't match dim(input, " << dim
909  << ") = " << inputDim;
910  }
911  }
912  return success();
913 }
914 
915 LogicalResult tosa::TileOp::inferReturnTypeComponents(
916  MLIRContext *context, ::std::optional<Location> location,
917  TileOp::Adaptor adaptor,
918  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
919  ArrayRef<int64_t> multiples = adaptor.getMultiples();
920  ShapeAdaptor inputShape(adaptor.getInput1().getType());
921  SmallVector<int64_t> outputShape;
922  if (!inputShape.hasRank()) {
923  outputShape.resize(multiples.size(), ShapedType::kDynamic);
924  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
925  return success();
926  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
927  return failure();
928 
929  // Any non-dynamic dimension can be multiplied by its multiple to a known size.
930  outputShape.reserve(multiples.size());
931  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
932  int64_t dim = inputShape.getDimSize(i);
933  if (dim != ShapedType::kDynamic)
934  dim *= multiples[i];
935  outputShape.push_back(dim);
936  }
937 
938  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
939  return success();
940 }
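// Illustrative example with assumed shapes: an input of [2, ?, 3] with
// multiples [3, 2, 4] infers [6, ?, 12].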
941 
942 LogicalResult tosa::TileOp::verify() {
943  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
944  ShapedType outputType = llvm::cast<ShapedType>(getType());
945  auto multiples = getMultiples();
946 
947  if (inputType.hasRank()) {
948  if (static_cast<size_t>(inputType.getRank()) != multiples.size())
949  return emitOpError("expect 'multiples' array to have length ")
950  << inputType.getRank() << " but got " << multiples.size() << ".";
951  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
952  return emitOpError("expect same input and output tensor rank.");
953  } else if (outputType.hasRank() &&
954  static_cast<size_t>(outputType.getRank()) != multiples.size())
955  return emitOpError("expect 'multiples' array to have length ")
956  << outputType.getRank() << " but got " << multiples.size() << ".";
957 
958  if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
959  return emitOpError(
960  "expect element of 'multiples' to be positive integer or -1.");
961 
962  return success();
963 }
964 
965 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
966  if (l.size() != r.size() || l.size() != 1)
967  return false;
968  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
969 }
970 
971 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
972  MLIRContext *context, ::std::optional<Location> location,
973  ReshapeOp::Adaptor adaptor,
974  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
975  ShapeAdaptor inputShape(adaptor.getInput1().getType());
976  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
977  llvm::SmallVector<int64_t> newShapeValue =
978  convertToMlirShape(adaptor.getNewShape());
979 
980  // We cannot infer from the total number of elements so we must take the
981  // shape attribute as exact.
982  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
983  inferredReturnShapes.push_back(
984  ShapedTypeComponents(newShapeValue, inputType));
985  return success();
986  }
987 
988  // Determine the number of elements covered by the product of all static
989  // dimensions in the new shape. This allows us to infer the length of the
990  // remaining dynamic dimension.
991  int64_t numElements = inputShape.getNumElements();
992  int64_t staticMul = 1;
993  for (auto val : newShapeValue) {
994  if (!ShapedType::isDynamic(val)) {
995  staticMul *= val;
996  }
997  }
998 
999  // Determine the length of the dynamic dimension.
1000  for (auto &val : newShapeValue) {
1001  if (ShapedType::isDynamic(val))
1002  val = numElements / staticMul;
1003  }
1004 
1005  inferredReturnShapes.push_back(
1006  ShapedTypeComponents(newShapeValue, inputType));
1007  return success();
1008 }
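// Illustrative example with assumed shapes: a static input of 2x3x4 (24
// elements) with new_shape = [-1, 6] gives staticMul = 6, so the dynamic dim
// resolves to 24 / 6 = 4 and the inferred shape is [4, 6].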
1009 
1010 llvm::LogicalResult tosa::ReshapeOp::verify() {
1011  TensorType inputType = getInput1().getType();
1012  RankedTensorType outputType = getType();
1013 
1014  if ((int64_t)getNewShape().size() != outputType.getRank())
1015  return emitOpError() << "new shape does not match result rank";
1016 
1017  for (auto [newShapeDim, outputShapeDim] :
1018  zip(getNewShape(), outputType.getShape())) {
1019  if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1020  newShapeDim != outputShapeDim)
1021  return emitOpError() << "new shape is inconsistent with result shape";
1022 
1023  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1024  return emitOpError() << "new shape has invalid tensor dimension size "
1025  << newShapeDim;
1026  }
1027 
1028  if (inputType.hasStaticShape()) {
1029  int64_t inputElementsNum = inputType.getNumElements();
1030  if (outputType.hasStaticShape()) {
1031  int64_t outputElementsNum = outputType.getNumElements();
1032  if (inputElementsNum != outputElementsNum) {
1033  return emitOpError() << "cannot reshape " << inputElementsNum
1034  << " elements into " << outputElementsNum;
1035  }
1036  }
1037 
1038  int64_t newShapeElementsNum = std::accumulate(
1039  getNewShape().begin(), getNewShape().end(), 1LL,
1040  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1041  bool isStaticNewShape =
1042  llvm::all_of(getNewShape(), [](int64_t s) { return s > 0; });
1043  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1044  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1045  return emitOpError() << "cannot reshape " << inputElementsNum
1046  << " elements into " << newShapeElementsNum;
1047  }
1048  }
1049 
1050  int missingDims = llvm::count(getNewShape(), -1);
1051  if (missingDims > 1)
1052  return emitOpError() << "expected at most one target dimension to be -1";
1053 
1054  return mlir::success();
1055 }
1056 
1057 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1058  // Perms must be constants.
1059  DenseIntElementsAttr permsAttr;
1060  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1061  return failure();
1062 
1063  perms.clear();
1064  for (auto v : permsAttr.getValues<APInt>())
1065  perms.push_back(v.getSExtValue());
1066 
1067  return success();
1068 }
1069 
1070 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1071  MLIRContext *context, ::std::optional<Location> location,
1072  TransposeOp::Adaptor adaptor,
1073  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1074  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1075  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1076 
1077  // We cannot infer anything from a rank-0 "permutation" tensor.
1078  if (permsShape.hasRank() && permsShape.getRank() == 0)
1079  return failure();
1080 
1081  // If the input rank or the permutation length is unknown, the output rank
1082  // is unknown.
1083  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1084  permsShape.isDynamicDim(0)) {
1085  inferredReturnShapes.push_back(ShapedTypeComponents());
1086  return success();
1087  }
1088 
1089  // This would imply the number of permutations does not match the rank of
1090  // the input which is illegal.
1091  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1092  return failure();
1093  }
1094 
1095  SmallVector<int64_t> outputShape;
1096  // Rank-0 means no permutations matter.
1097  if (inputShape.getRank() == 0) {
1098  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1099  return success();
1100  }
1101 
1102  // Check whether the input dimensions are all the same.
1103  bool allTheSame = true;
1104  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1105  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1106  allTheSame = false;
1107  break;
1108  }
1109  }
1110 
1111  // If all of the input dimensions are the same we don't care about the
1112  // permutation.
1113  if (allTheSame) {
1114  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1115  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1116  return success();
1117  }
1118 
1119  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1120  // If the permutations are a constant we can directly determine the output
1121  // shape.
1122  DenseIntElementsAttr attr;
1123  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1124  attr.getType().getRank() == 1) {
1125  ShapeAdaptor permShape = attr;
1126  // Constant permutation must be the same length as the input rank.
1127  if (inputShape.getRank() != permShape.getRank())
1128  return emitOptionalError(location,
1129  "constant permutation must be the same length"
1130  " as the input rank");
1131 
1132  // Constant permutation values must be within the input rank.
1133  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1134  if (inputShape.getRank() <= permShape.getDimSize(i))
1135  return failure();
1136  }
1137 
1138  outputShape.reserve(inputShape.getRank());
1139  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1140  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1141  }
1142  }
1143 
1144  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1145  return success();
1146 }
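// Illustrative example with assumed shapes: an input of [1, 2, 3] with a
// constant perms of [2, 0, 1] infers [3, 1, 2], since output dim i takes the
// size of input dim perms[i].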
1147 
1148 LogicalResult tosa::TransposeOp::verify() {
1149  TensorType inputType = getInput1().getType();
1150  TensorType permType = getPerms().getType();
1151  TensorType outputType = getOutput().getType();
1152 
1153  if (permType.hasRank() && permType.getRank() != 1)
1154  return emitOpError()
1155  << "expected permutation tensor to be rank 1 but got rank "
1156  << permType.getRank();
1157  if (inputType.hasRank() && permType.hasRank())
1158  if (!permType.isDynamicDim(0) &&
1159  permType.getDimSize(0) != inputType.getRank())
1160  return emitOpError() << "expected permutation tensor dim 0 to have size "
1161  << inputType.getRank()
1162  << " (input rank) but got size "
1163  << permType.getDimSize(0);
1164  if (inputType.hasRank() && outputType.hasRank() &&
1165  inputType.getRank() != outputType.getRank())
1166  return emitOpError()
1167  << "expected input tensor rank to equal result tensor rank";
1168  if (outputType.hasRank() && permType.hasRank())
1169  if (!permType.isDynamicDim(0) &&
1170  permType.getDimSize(0) != outputType.getRank())
1171  return emitOpError() << "expected permutation tensor dim 0 to have size "
1172  << outputType.getRank()
1173  << " (output rank) but got size "
1174  << permType.getDimSize(0);
1175 
1176  SmallVector<int32_t> constantPerms;
1177  if (succeeded(getConstantPerms(constantPerms))) {
1178  // Assert that the permutation tensor has a rank, which means that the
1179  // rank has been verified above.
1180  assert(permType.hasRank() &&
1181  "Unexpectedly found permutation tensor without rank");
1182  if (!llvm::all_of(constantPerms,
1183  [&constantPerms](int32_t s) {
1184  return s >= 0 &&
1185  static_cast<size_t>(s) < constantPerms.size();
1186  }) ||
1187  !isPermutationVector(llvm::to_vector(llvm::map_range(
1188  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1189  return emitOpError() << "expected valid permutation tensor";
1190 
1191  // Verify that the types of the input and output tensors are properly
1192  // permuted.
1193  if (inputType.hasRank() && outputType.hasRank()) {
1194  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1195  inputType.getRank() == outputType.getRank());
1196 
1197  for (auto i = 0; i < outputType.getRank(); i++) {
1198  if (inputType.isDynamicDim(constantPerms[i]) ||
1199  outputType.isDynamicDim(i))
1200  continue;
1201 
1202  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1203  return emitOpError()
1204  << "expected output tensor dim " << i << " to match "
1205  << "input dim " << constantPerms[i] << " with value of "
1206  << inputType.getDimSize(constantPerms[i]);
1207  }
1208  }
1209  }
1210  return success();
1211 }
1212 
1213 LogicalResult TransposeOp::reifyResultShapes(
1214  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1215 
1216  SmallVector<int32_t> transposePerms;
1217  if (getConstantPerms(transposePerms).failed())
1218  return failure();
1219 
1220  Value input = getInput1();
1221  auto inputType = cast<TensorType>(input.getType());
1222 
1223  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1224  for (auto dim : transposePerms) {
1225  int32_t dimInInput = transposePerms[dim];
1226  if (inputType.isDynamicDim(dimInInput))
1227  returnedDims[dim] =
1228  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1229  .getResult();
1230  else
1231  returnedDims[dim] =
1232  builder.getIndexAttr(inputType.getDimSize(dimInInput));
1233  }
1234 
1235  reifiedReturnShapes.emplace_back(std::move(returnedDims));
1236  return success();
1237 }
1238 
1239 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1240  MLIRContext *context, ::std::optional<Location> location,
1241  GatherOp::Adaptor adaptor,
1242  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1243  llvm::SmallVector<int64_t> outputShape;
1244  outputShape.resize(3, ShapedType::kDynamic);
1245 
1246  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1247  if (valuesShape.hasRank()) {
1248  outputShape[0] = valuesShape.getDimSize(0);
1249  outputShape[2] = valuesShape.getDimSize(2);
1250  }
1251 
1252  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1253  if (indicesShape.hasRank()) {
1254  if (outputShape[0] == ShapedType::kDynamic)
1255  outputShape[0] = indicesShape.getDimSize(0);
1256  if (outputShape[1] == ShapedType::kDynamic)
1257  outputShape[1] = indicesShape.getDimSize(1);
1258  }
1259 
1260  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1261  return success();
1262 }
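// Illustrative example with assumed shapes: values of [N, K, C] and indices
// of [N, W] infer an output shape of [N, W, C].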
1263 
1264 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1265  MLIRContext *context, ::std::optional<Location> location,
1266  ResizeOp::Adaptor adaptor,
1267  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1268  llvm::SmallVector<int64_t, 4> outputShape;
1269  outputShape.resize(4, ShapedType::kDynamic);
1270 
1271  ShapeAdaptor inputShape(adaptor.getInput().getType());
1272  if (!inputShape.hasRank())
1273  return failure();
1274 
1275  outputShape[0] = inputShape.getDimSize(0);
1276  outputShape[3] = inputShape.getDimSize(3);
1277  int64_t inputHeight = inputShape.getDimSize(1);
1278  int64_t inputWidth = inputShape.getDimSize(2);
1279 
1280  if ((inputHeight == ShapedType::kDynamic) ||
1281  (inputWidth == ShapedType::kDynamic))
1282  return failure();
1283 
1284  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1285  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1286  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1287 
1288  // Compute the output shape based on attributes: scale, offset, and border.
1289  outputShape[1] =
1290  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1291  scaleInt[1]) +
1292  1;
1293 
1294  outputShape[2] =
1295  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1296  scaleInt[3]) +
1297  1;
1298 
1299  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1300  return success();
1301 }
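// Illustrative example with assumed attributes: an 8x8 spatial input with
// scale = [4, 2, 4, 2], offset = [0, 0] and border = [0, 0] gives
// ((8 - 1) * 4 - 0 + 0) / 2 + 1 = 15 for both spatial output dims.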
1302 
1303 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1304  MLIRContext *context, ::std::optional<Location> location,
1305  ScatterOp::Adaptor adaptor,
1306  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1307  llvm::SmallVector<int64_t> outputShape;
1308  outputShape.resize(3, ShapedType::kDynamic);
1309 
1310  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1311  if (valuesInShape.hasRank()) {
1312  outputShape[0] = valuesInShape.getDimSize(0);
1313  outputShape[1] = valuesInShape.getDimSize(1);
1314  outputShape[2] = valuesInShape.getDimSize(2);
1315  }
1316 
1317  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1318  if (indicesShape.hasRank()) {
1319  if (outputShape[0] == ShapedType::kDynamic)
1320  outputShape[0] = indicesShape.getDimSize(0);
1321  }
1322 
1323  ShapeAdaptor inputShape(adaptor.getInput().getType());
1324  if (inputShape.hasRank()) {
1325  if (outputShape[0] == ShapedType::kDynamic)
1326  outputShape[0] = inputShape.getDimSize(0);
1327  if (outputShape[2] == ShapedType::kDynamic)
1328  outputShape[2] = inputShape.getDimSize(2);
1329  }
1330 
1331  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1332  return success();
1333 }
1334 
1335 static LogicalResult ReduceInferReturnTypes(
1336  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1337  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1338  int64_t axisVal = axis.getValue().getSExtValue();
1339  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1340  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1341  return success();
1342  }
1343 
1344  SmallVector<int64_t> outputShape;
1345  operandShape.getDims(outputShape);
1346  outputShape[axisVal] = 1;
1347  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1348  return success();
1349 }
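// Illustrative example with an assumed shape: an input of [2, 3, 4] reduced
// along axis = 1 infers [2, 1, 4].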
1350 
1351 #define COMPATIBLE_RETURN_TYPES(OP) \
1352  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1353  if (l.size() != r.size() || l.size() != 1) \
1354  return false; \
1355  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1356  return false; \
1357  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1358  }
1359 
1360 #define REDUCE_SHAPE_INFER(OP) \
1361  LogicalResult OP::inferReturnTypeComponents( \
1362  MLIRContext *context, ::std::optional<Location> location, \
1363  OP::Adaptor adaptor, \
1364  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1365  Type inputType = \
1366  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1367  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1368  const Properties &prop = adaptor.getProperties(); \
1369  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1370  inferredReturnShapes); \
1371  } \
1372  COMPATIBLE_RETURN_TYPES(OP)
1373 
1374 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1375 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1376 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1377 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1378 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1379 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1380 #undef REDUCE_SHAPE_INFER
1381 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1382 #undef COMPATIBLE_RETURN_TYPES
1383 
1384 template <typename T>
1385 static LogicalResult verifyReduceOp(T op) {
1386  // All TOSA reduce Ops have input, output and axis.
1387  TensorType inputType = op.getInput().getType();
1388  TensorType outputType = op.getOutput().getType();
1389  int32_t reduceAxis = op.getAxis();
1390 
1391  if (reduceAxis < 0) {
1392  op.emitOpError("reduce axis must not be negative");
1393  return failure();
1394  }
1395  if (inputType.hasRank()) {
1396  int64_t inputRank = inputType.getRank();
1397  // We allow for a special case where the input/output shape has rank 0 and
1398  // axis is also 0.
1399  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1400  op.emitOpError("expect input tensor rank (")
1401  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1402  << ")";
1403  return failure();
1404  }
1405  }
1406  if (outputType.hasRank()) {
1407  int64_t outputRank = outputType.getRank();
1408  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1409  op.emitOpError(
1410  "expect output tensor rank to be equal to input tensor rank");
1411  return failure();
1412  }
1413  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1414  op.emitOpError("expect output tensor rank (")
1415  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1416  << ")";
1417  return failure();
1418  }
1419  // We can only verify the reduced dimension size to be 1 if this is not
1420  // the special case of output rank == 0.
1421  if (outputRank != 0) {
1422  auto outputShape = outputType.getShape();
1423  if (!outputType.isDynamicDim(reduceAxis) &&
1424  outputShape[reduceAxis] != 1) {
1425  op.emitOpError("expect reduced dimension size to be 1, got ")
1426  << outputShape[reduceAxis];
1427  return failure();
1428  }
1429  }
1430  }
1431  return success();
1432 }
1433 
1434 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1435 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1436 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1437 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1438 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1439 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1440 
1441 static LogicalResult NAryInferReturnTypes(
1442  const ValueShapeRange &operands,
1443  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1444  llvm::SmallVector<int64_t> outShape;
1445  if (resolveBroadcastShape(operands, outShape).failed()) {
1446  inferredReturnShapes.push_back(ShapedTypeComponents());
1447  } else {
1448  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1449  }
1450  return success();
1451 }
1452 
1453 #define NARY_SHAPE_INFER(OP) \
1454  LogicalResult OP::inferReturnTypeComponents( \
1455  MLIRContext *context, ::std::optional<Location> location, \
1456  ValueShapeRange operands, DictionaryAttr attributes, \
1457  OpaqueProperties properties, RegionRange regions, \
1458  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1459  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1460  }
1461 
1462 NARY_SHAPE_INFER(tosa::AbsOp)
1463 NARY_SHAPE_INFER(tosa::AddOp)
1464 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1465 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1466 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1467 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1468 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1469 NARY_SHAPE_INFER(tosa::CastOp)
1470 NARY_SHAPE_INFER(tosa::CeilOp)
1471 NARY_SHAPE_INFER(tosa::ClampOp)
1472 NARY_SHAPE_INFER(tosa::ClzOp)
1473 NARY_SHAPE_INFER(tosa::CosOp)
1474 NARY_SHAPE_INFER(tosa::ExpOp)
1475 NARY_SHAPE_INFER(tosa::FloorOp)
1476 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1477 NARY_SHAPE_INFER(tosa::GreaterOp)
1478 NARY_SHAPE_INFER(tosa::IdentityOp)
1479 NARY_SHAPE_INFER(tosa::IntDivOp)
1480 NARY_SHAPE_INFER(tosa::LogOp)
1481 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1482 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1483 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1484 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1485 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1486 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1487 NARY_SHAPE_INFER(tosa::MaximumOp)
1488 NARY_SHAPE_INFER(tosa::MinimumOp)
1489 NARY_SHAPE_INFER(tosa::MulOp)
1490 NARY_SHAPE_INFER(tosa::NegateOp)
1491 NARY_SHAPE_INFER(tosa::PowOp)
1492 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1493 NARY_SHAPE_INFER(tosa::RescaleOp)
1494 NARY_SHAPE_INFER(tosa::ReverseOp)
1495 NARY_SHAPE_INFER(tosa::RsqrtOp)
1496 NARY_SHAPE_INFER(tosa::SinOp)
1497 NARY_SHAPE_INFER(tosa::SelectOp)
1498 NARY_SHAPE_INFER(tosa::SubOp)
1499 NARY_SHAPE_INFER(tosa::TanhOp)
1500 NARY_SHAPE_INFER(tosa::ErfOp)
1501 NARY_SHAPE_INFER(tosa::SigmoidOp)
1502 #undef NARY_SHAPE_INFER
1503 
1504 static LogicalResult poolingInferReturnTypes(
1505  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1506  ArrayRef<int64_t> pad,
1507  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1508  llvm::SmallVector<int64_t> outputShape;
1509  outputShape.resize(4, ShapedType::kDynamic);
1510 
1511  // If the input type is unranked, we only know the output rank; dims stay dynamic.
1512  if (!inputShape) {
1513  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1514  return success();
1515  }
1516 
1517  // The batch and channel dimensions pass through the pooling layer unchanged.
1518  outputShape[0] = inputShape.getDimSize(0);
1519  outputShape[3] = inputShape.getDimSize(3);
1520 
1521  int64_t height = inputShape.getDimSize(1);
1522  int64_t width = inputShape.getDimSize(2);
1523 
1524  if (!ShapedType::isDynamic(height)) {
1525  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1526  outputShape[1] = padded / stride[0] + 1;
1527  }
1528 
1529  if (!ShapedType::isDynamic(width)) {
1530  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1531  outputShape[2] = padded / stride[1] + 1;
1532  }
1533 
1534  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1535  return success();
1536 }
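// A worked example of the pooling output-size arithmetic above (values are
// illustrative): with an input height of 7, pad = {1, 1, ...}, kernel = {3, ...}
// and stride = {2, ...}:
//   padded         = 7 + 1 + 1 - 3 = 6
//   outputShape[1] = 6 / 2 + 1     = 4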
1537 
1538 LogicalResult Conv2DOp::inferReturnTypeComponents(
1539  MLIRContext *context, ::std::optional<Location> location,
1540  Conv2DOp::Adaptor adaptor,
1541  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1542  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1543 
1544  int64_t inputWidth = ShapedType::kDynamic;
1545  int64_t inputHeight = ShapedType::kDynamic;
1546  int64_t weightWidth = ShapedType::kDynamic;
1547  int64_t weightHeight = ShapedType::kDynamic;
1548 
1549  // Input shape describes input width/height and batch.
1550 
1551  ShapeAdaptor inputShape(adaptor.getInput().getType());
1552  if (inputShape.hasRank()) {
1553  outputShape[0] = inputShape.getDimSize(0);
1554  inputHeight = inputShape.getDimSize(1);
1555  inputWidth = inputShape.getDimSize(2);
1556  }
1557 
1558  // The weight shape describes the filter height/width and the output channels.
1559  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1560  if (weightShape.hasRank()) {
1561  outputShape[3] = weightShape.getDimSize(0);
1562  weightHeight = weightShape.getDimSize(1);
1563  weightWidth = weightShape.getDimSize(2);
1564  }
1565 
1566  // Bias shape can describe the output channels.
1567  ShapeAdaptor biasShape(adaptor.getBias().getType());
1568  if (biasShape.hasRank()) {
1569  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1570  ? biasShape.getDimSize(0)
1571  : outputShape[3];
1572  }
1573 
1574  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1575  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1576  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1577 
1578  if (!ShapedType::isDynamic(inputHeight) &&
1579  !ShapedType::isDynamic(weightHeight)) {
1580  int64_t inputSize = inputHeight + padding[0] + padding[1];
1581  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1582  int64_t unstridedResult = inputSize - filterSize + 1;
1583  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1584  }
1585 
1586  if (!ShapedType::isDynamic(inputWidth) &&
1587  !ShapedType::isDynamic(weightWidth)) {
1588  int64_t inputSize = inputWidth + padding[2] + padding[3];
1589  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1590  int64_t unstridedResult = inputSize - filterSize + 1;
1591  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1592  }
1593 
1594  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1595  return success();
1596 }
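// A worked example of the convolution output-size arithmetic above (values are
// illustrative): inputHeight = 16, padding = {1, 1, ...}, weightHeight = 3,
// dilation = {2, ...} and stride = {2, ...}:
//   inputSize       = 16 + 1 + 1       = 18
//   filterSize      = (3 - 1) * 2 + 1  = 5
//   unstridedResult = 18 - 5 + 1       = 14
//   outputShape[1]  = (14 - 1) / 2 + 1 = 7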
1597 
1598 LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1599 
1600 LogicalResult Conv3DOp::inferReturnTypeComponents(
1601  MLIRContext *context, ::std::optional<Location> location,
1602  Conv3DOp::Adaptor adaptor,
1603  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1604  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1605 
1606  int64_t inputWidth = ShapedType::kDynamic;
1607  int64_t inputHeight = ShapedType::kDynamic;
1608  int64_t inputDepth = ShapedType::kDynamic;
1609 
1610  int64_t weightWidth = ShapedType::kDynamic;
1611  int64_t weightHeight = ShapedType::kDynamic;
1612  int64_t weightDepth = ShapedType::kDynamic;
1613 
1614  // Input shape describes the batch, depth, height, and width.
1615  ShapeAdaptor inputShape(adaptor.getInput().getType());
1616  if (inputShape.hasRank()) {
1617  outputShape[0] = inputShape.getDimSize(0);
1618  inputDepth = inputShape.getDimSize(1);
1619  inputHeight = inputShape.getDimSize(2);
1620  inputWidth = inputShape.getDimSize(3);
1621  }
1622 
1623  // The weight shape describes the filter depth/height/width and output channels.
1624  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1625  if (weightShape.hasRank()) {
1626  outputShape[4] = weightShape.getDimSize(0);
1627  weightDepth = weightShape.getDimSize(1);
1628  weightHeight = weightShape.getDimSize(2);
1629  weightWidth = weightShape.getDimSize(3);
1630  }
1631 
1632  // Bias shape can describe the output channels.
1633  ShapeAdaptor biasShape(adaptor.getBias().getType());
1634  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1635  outputShape[4] = biasShape.getDimSize(0);
1636  }
1637 
1638  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1639  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1640  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1641 
1642  if (!ShapedType::isDynamic(inputDepth) &&
1643  !ShapedType::isDynamic(weightDepth)) {
1644  int64_t inputSize = inputDepth + pad[0] + pad[1];
1645  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1646  int64_t unstridedResult = inputSize - filterSize + 1;
1647  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1648  }
1649 
1650  if (!ShapedType::isDynamic(inputHeight) &&
1651  !ShapedType::isDynamic(weightHeight)) {
1652  int64_t inputSize = inputHeight + pad[2] + pad[3];
1653  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1654  int64_t unstridedResult = inputSize - filterSize + 1;
1655  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1656  }
1657 
1658  if (!ShapedType::isDynamic(inputWidth) &&
1659  !ShapedType::isDynamic(weightWidth)) {
1660  int64_t inputSize = inputWidth + pad[4] + pad[5];
1661  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1662  int64_t unstridedResult = inputSize - filterSize + 1;
1663  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1664  }
1665 
1666  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1667  return success();
1668 }
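// The depth/height/width arithmetic above mirrors Conv2DOp, extended to three
// spatial dimensions. Illustrative numbers: inputDepth = 10, pad = {1, 1, ...},
// weightDepth = 3, dilation = {1, ...} and stride = {1, ...} give
//   inputSize = 12, filterSize = 3, unstridedResult = 10,
//   outputShape[1] = (10 - 1) / 1 + 1 = 10.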
1669 
1670 LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1671 
1672 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1673  MLIRContext *context, ::std::optional<Location> location,
1674  AvgPool2dOp::Adaptor adaptor,
1675  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1676  ShapeAdaptor inputShape(adaptor.getInput().getType());
1677  const Properties &prop = adaptor.getProperties();
1678  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1679  inferredReturnShapes);
1680 }
1681 
1682 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1683  MLIRContext *context, ::std::optional<Location> location,
1684  MaxPool2dOp::Adaptor adaptor,
1685  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1686  ShapeAdaptor inputShape(adaptor.getInput().getType());
1687  const Properties &prop = adaptor.getProperties();
1688  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1689  inferredReturnShapes);
1690 }
1691 
1692 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1693  MLIRContext *context, ::std::optional<Location> location,
1694  DepthwiseConv2DOp::Adaptor adaptor,
1695  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1696  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1697 
1698  int64_t inputWidth = ShapedType::kDynamic;
1699  int64_t inputHeight = ShapedType::kDynamic;
1700  int64_t inputChannels = ShapedType::kDynamic;
1701 
1702  int64_t weightWidth = ShapedType::kDynamic;
1703  int64_t weightHeight = ShapedType::kDynamic;
1704  int64_t depthChannels = ShapedType::kDynamic;
1705 
1706  // Input shape describes the batch, height, width, and input channels.
1707  ShapeAdaptor inputShape(adaptor.getInput().getType());
1708  if (inputShape.hasRank()) {
1709  outputShape[0] = inputShape.getDimSize(0);
1710  inputHeight = inputShape.getDimSize(1);
1711  inputWidth = inputShape.getDimSize(2);
1712  inputChannels = inputShape.getDimSize(3);
1713  }
1714 
1715  // The weight shape describes the filter height/width, input channels, and depth multiplier.
1716  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1717  if (weightShape.hasRank()) {
1718  weightHeight = weightShape.getDimSize(0);
1719  weightWidth = weightShape.getDimSize(1);
1720  inputChannels = ShapedType::isDynamic(inputChannels)
1721  ? weightShape.getDimSize(2)
1722  : inputChannels;
1723  depthChannels = weightShape.getDimSize(3);
1724  }
1725 
1726  // If both inputChannels and depthChannels are available we can determine
1727  // the output channels.
1728  if (!ShapedType::isDynamic(inputChannels) &&
1729  !ShapedType::isDynamic(depthChannels)) {
1730  outputShape[3] = inputChannels * depthChannels;
1731  }
1732 
1733  // Bias shape can describe the output channels.
1734  ShapeAdaptor biasShape(adaptor.getBias().getType());
1735  if (biasShape.hasRank()) {
1736  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1737  ? biasShape.getDimSize(0)
1738  : outputShape[3];
1739  }
1740 
1741  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1742  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1743  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1744 
1745  if (!ShapedType::isDynamic(inputHeight) &&
1746  !ShapedType::isDynamic(weightHeight)) {
1747  int64_t inputSize = inputHeight + padding[0] + padding[1];
1748  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1749  int64_t unstridedResult = inputSize - filterSize + 1;
1750  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1751  }
1752 
1753  if (!ShapedType::isDynamic(inputWidth) &&
1754  !ShapedType::isDynamic(weightWidth)) {
1755  int64_t inputSize = inputWidth + padding[2] + padding[3];
1756  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1757  int64_t unstridedResult = inputSize - filterSize + 1;
1758  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1759  }
1760 
1761  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1762  return success();
1763 }
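// Note on the channel arithmetic above: for depthwise convolution the output
// channel count is the product of the input channels and the depth multiplier
// carried in the weight's last dimension. A hypothetical example:
// inputChannels = 8 and depthChannels = 3 give outputShape[3] = 24. The
// spatial dimensions follow the same pad/dilation/stride arithmetic as
// Conv2DOp.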
1764 
1765 LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1766 
1767 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1768  MLIRContext *context, ::std::optional<Location> location,
1769  TransposeConv2DOp::Adaptor adaptor,
1770  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1771  // outputShape is mutable.
1772  llvm::SmallVector<int64_t> outputShape =
1773  convertToMlirShape(adaptor.getOutShape());
1774 
1775  int64_t inputWidth = ShapedType::kDynamic;
1776  int64_t inputHeight = ShapedType::kDynamic;
1777  int64_t weightWidth = ShapedType::kDynamic;
1778  int64_t weightHeight = ShapedType::kDynamic;
1779 
1780  // Input shape describes input width/height and batch.
1781  ShapeAdaptor inputShape(adaptor.getInput().getType());
1782  if (inputShape.hasRank()) {
1783  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1784  ? inputShape.getDimSize(0)
1785  : outputShape[0];
1786  inputHeight = inputShape.getDimSize(1);
1787  inputWidth = inputShape.getDimSize(2);
1788  }
1789 
1790  // The weight shape describes the filter height/width and the output channels.
1791  ShapeAdaptor weightShape(adaptor.getFilter().getType());
1792  if (weightShape.hasRank()) {
1793  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1794  ? weightShape.getDimSize(0)
1795  : outputShape[3];
1796  weightHeight = weightShape.getDimSize(1);
1797  weightWidth = weightShape.getDimSize(2);
1798  }
1799 
1800  // Bias shape can describe the output channels.
1801  ShapeAdaptor biasShape(adaptor.getBias().getType());
1802  if (biasShape.hasRank()) {
1803  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1804  ? biasShape.getDimSize(0)
1805  : outputShape[3];
1806  }
1807 
1808  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
1809  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1810 
1811  if (!ShapedType::isDynamic(inputHeight) &&
1812  !ShapedType::isDynamic(weightHeight)) {
1813  int64_t calculateSize =
1814  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
1815  outputShape[1] =
1816  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
1817  }
1818 
1819  if (!ShapedType::isDynamic(inputWidth) &&
1820  !ShapedType::isDynamic(weightWidth)) {
1821  int64_t calculateSize =
1822  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
1823  outputShape[2] =
1824  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
1825  }
1826 
1827  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1828  return success();
1829 }
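// A worked example of the transposed-convolution size arithmetic above (values
// are illustrative): inputHeight = 4, stride = {2, ...}, out_pad = {0, 0, ...}
// and weightHeight = 3 give
//   outputShape[1] = (4 - 1) * 2 + 0 + 0 + 3 = 9,
// and the computed value is only used when the corresponding out_shape entry
// is dynamic.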
1830 
1831 LogicalResult IfOp::inferReturnTypeComponents(
1832  MLIRContext *context, ::std::optional<Location> location,
1833  IfOp::Adaptor adaptor,
1834  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1835  llvm::SmallVector<tosa::YieldOp> yieldOps;
1836  for (Region *region : adaptor.getRegions()) {
1837  for (auto &block : *region)
1838  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1839  yieldOps.push_back(returnOp);
1840  }
1841 
1842  if (yieldOps.empty())
1843  return failure();
1844 
1845  // Get the initial type information for the yield op.
1846  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1847  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1848  for (auto operand : yieldOps.front().getOperands()) {
1849  resultKnowledge.push_back(
1850  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1851  }
1852 
1853  for (auto yieldOp : yieldOps) {
1854  if (resultKnowledge.size() != yieldOp.getNumOperands())
1855  return failure();
1856 
1857  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1858  int32_t index = it.index();
1859  auto meet = ValueKnowledge::meet(
1860  resultKnowledge[index],
1861  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1862  if (!meet)
1863  continue;
1864  resultKnowledge[index] = meet;
1865  }
1866  }
1867 
1868  for (const ValueKnowledge &result : resultKnowledge) {
1869  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1870  }
1871 
1872  return success();
1873 }
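// A sketch of the refinement performed above, assuming ValueKnowledge::meet
// keeps the more specific of two compatible dimensions: if one branch yields
// tensor<2x?xf32> and the other yields tensor<2x4xf32>, the combined result
// knowledge becomes a ranked 2x4 shape; if the yields carry conflicting
// static dimensions, meet produces no knowledge and the previous entry is
// kept unchanged.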
1874 
1875 LogicalResult WhileOp::inferReturnTypeComponents(
1876  MLIRContext *context, ::std::optional<Location> location,
1877  WhileOp::Adaptor adaptor,
1878  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1879  llvm::SmallVector<tosa::YieldOp> yieldOps;
1880  for (auto &block : adaptor.getBody())
1881  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1882  yieldOps.push_back(returnOp);
1883 
1884  // TOSA's while must have a tosa.yield as its terminator. If one is not
1885  // found, this tosa.while is invalid.
1886  if (yieldOps.empty())
1887  return failure();
1888 
1889  // Get the initial type information from the operand types.
1890  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1891  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1892  for (auto operand : yieldOps.front().getOperands()) {
1893  resultKnowledge.push_back(
1894  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1895  }
1896 
1897  for (auto yieldOp : yieldOps) {
1898  if (resultKnowledge.size() != yieldOp.getNumOperands())
1899  return failure();
1900 
1901  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1902  int32_t index = it.index();
1903  if (auto meet = ValueKnowledge::meet(
1904  resultKnowledge[index],
1905  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1906  resultKnowledge[index] = meet;
1907  }
1908  }
1909  }
1910 
1911  for (const ValueKnowledge &result : resultKnowledge) {
1912  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1913  }
1914 
1915  return success();
1916 }
1917 
1918 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
1919  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
1920  return llvm::to_vector<4>(vt.getShape());
1921  return std::nullopt;
1922 }
1923 
1924 // The parse and print methods of IfOp follow the implementation in the SCF dialect.
1925 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1926  // Create the regions for 'then'.
1927  result.regions.reserve(2);
1928  Region *thenRegion = result.addRegion();
1929  Region *elseRegion = result.addRegion();
1930 
1931  auto &builder = parser.getBuilder();
1932  OpAsmParser::UnresolvedOperand cond;
1933  // Create an i1 tensor type for the boolean condition.
1934  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1935  if (parser.parseOperand(cond) ||
1936  parser.resolveOperand(cond, i1Type, result.operands))
1937  return failure();
1938  // Parse optional results type list.
1939  if (parser.parseOptionalArrowTypeList(result.types))
1940  return failure();
1941  // Parse the 'then' region.
1942  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1943  return failure();
1944 
1945  // If we find an 'else' keyword then parse the 'else' region.
1946  if (!parser.parseOptionalKeyword("else")) {
1947  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1948  return failure();
1949  }
1950 
1951  // Parse the optional attribute list.
1952  if (parser.parseOptionalAttrDict(result.attributes))
1953  return failure();
1954  return success();
1955 }
1956 
1957 void IfOp::print(OpAsmPrinter &p) {
1958  bool printBlockTerminators = false;
1959 
1960  p << " " << getCond();
1961  if (!getResults().empty()) {
1962  p << " -> (" << getResultTypes() << ")";
1963  // Print yield explicitly if the op defines values.
1964  printBlockTerminators = true;
1965  }
1966  p << ' ';
1967  p.printRegion(getThenBranch(),
1968  /*printEntryBlockArgs=*/false,
1969  /*printBlockTerminators=*/printBlockTerminators);
1970 
1971  // Print the 'else' region if it exists and has a block.
1972  auto &elseRegion = getElseBranch();
1973  if (!elseRegion.empty()) {
1974  p << " else ";
1975  p.printRegion(elseRegion,
1976  /*printEntryBlockArgs=*/false,
1977  /*printBlockTerminators=*/printBlockTerminators);
1978  }
1979 
1980  p.printOptionalAttrDict((*this)->getAttrs());
1981 }
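// An illustrative textual form handled by the IfOp parse/print methods above
// (operand names and types are hypothetical):
//
//   %r = tosa.cond_if %cond -> (tensor<f32>) {
//     tosa.yield %then_value : tensor<f32>
//   } else {
//     tosa.yield %else_value : tensor<f32>
//   }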
1982 
1983 LogicalResult ReverseOp::verify() {
1984  TensorType inputType = getInput1().getType();
1985  TensorType outputType = getOutput().getType();
1986  int32_t reverseAxis = getAxis();
1987 
1988  if (reverseAxis < 0)
1989  return emitOpError("expected non-negative reverse axis");
1990  if (inputType.hasRank()) {
1991  int64_t inputRank = inputType.getRank();
1992  // We allow for a special case where the input/output shape has rank 0 and
1993  // axis is also 0.
1994  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1995  return emitOpError("expect input tensor rank (")
1996  << inputRank << ") to be larger than reverse axis (" << reverseAxis
1997  << ")";
1998  }
1999  if (outputType.hasRank()) {
2000  int64_t outputRank = outputType.getRank();
2001  if (inputType.hasRank() && outputRank != inputType.getRank())
2002  return emitOpError(
2003  "expect output tensor rank to be equal to input tensor rank");
2004  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
2005  return emitOpError("expect output tensor rank (")
2006  << outputRank << ") to be larger than reverse axis ("
2007  << reverseAxis << ")";
2008  }
2009  return success();
2010 }
2011 
2012 // The parse and print methods of WhileOp follow the implementation in the SCF dialect.
2013 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2014  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2015  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2016  Region *cond = result.addRegion();
2017  Region *body = result.addRegion();
2018 
2019  OptionalParseResult listResult =
2020  parser.parseOptionalAssignmentList(regionArgs, operands);
2021  if (listResult.has_value() && failed(listResult.value()))
2022  return failure();
2023 
2024  FunctionType functionType;
2025  SMLoc typeLoc = parser.getCurrentLocation();
2026  if (failed(parser.parseColonType(functionType)))
2027  return failure();
2028 
2029  result.addTypes(functionType.getResults());
2030 
2031  if (functionType.getNumInputs() != operands.size()) {
2032  return parser.emitError(typeLoc)
2033  << "expected as many input types as operands "
2034  << "(expected " << operands.size() << " got "
2035  << functionType.getNumInputs() << ")";
2036  }
2037 
2038  // Resolve input operands.
2039  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2040  parser.getCurrentLocation(),
2041  result.operands)))
2042  return failure();
2043 
2044  // Propagate the types into the region arguments.
2045  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2046  regionArgs[i].type = functionType.getInput(i);
2047 
2048  return failure(parser.parseRegion(*cond, regionArgs) ||
2049  parser.parseKeyword("do") || parser.parseRegion(*body) ||
2050  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2051 }
2052 
2053 static void printInitializationList(OpAsmPrinter &parser,
2054  Block::BlockArgListType blocksArgs,
2055  ValueRange initializers,
2056  StringRef prefix = "") {
2057  assert(blocksArgs.size() == initializers.size() &&
2058  "expected same length of arguments and initializers");
2059  if (initializers.empty())
2060  return;
2061 
2062  parser << prefix << '(';
2063  llvm::interleaveComma(
2064  llvm::zip(blocksArgs, initializers), parser,
2065  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
2066  parser << ")";
2067 }
2068 
2069 void WhileOp::print(OpAsmPrinter &parser) {
2070  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2071  " ");
2072  parser << " : ";
2073  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2074  parser << ' ';
2075  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2076  parser << " do ";
2077  parser.printRegion(getBody());
2078  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2079 }
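// An illustrative textual form handled by the WhileOp parse/print methods
// above (names and types are hypothetical):
//
//   %res = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     // condition region, terminated by a tosa.yield of an i1 tensor
//   } do {
//     // body region, terminated by a tosa.yield of the iteration values
//   }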
2080 
2081 //===----------------------------------------------------------------------===//
2082 // TOSA Attribute Definitions.
2083 //===----------------------------------------------------------------------===//
2084 
2085 #define GET_ATTRDEF_CLASSES
2086 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2087 
2088 //===----------------------------------------------------------------------===//
2089 // TOSA Operator Definitions.
2090 //===----------------------------------------------------------------------===//
2091 
2092 #define GET_OP_CLASSES
2093 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"