TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Tosa dialect interface includes.
40 //===----------------------------------------------------------------------===//
41 
42 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
43 
44 namespace {
45 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // Dialect Function Inliner Interface.
49 //===----------------------------------------------------------------------===//
50 struct TosaInlinerInterface : public DialectInlinerInterface {
51  using DialectInlinerInterface::DialectInlinerInterface;
52 
53  //===--------------------------------------------------------------------===//
54  // Analysis Hooks.
55  //===--------------------------------------------------------------------===//
56 
57  /// All operations can be inlined by default.
58  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
59  IRMapping &map) const final {
60  return true;
61  }
62 
63  /// All regions with If and While parent operators can be inlined.
64  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return (isa<tosa::IfOp>(dest->getParentOp()) ||
67  isa<tosa::WhileOp>(dest->getParentOp()));
68  }
69 };
70 
71 /// This class implements the bytecode interface for the Tosa dialect.
72 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
73  TosaDialectBytecodeInterface(Dialect *dialect)
74  : BytecodeDialectInterface(dialect) {}
75 
76  //===--------------------------------------------------------------------===//
77  // Attributes
78 
79  Attribute readAttribute(DialectBytecodeReader &reader) const override {
80  return ::readAttribute(getContext(), reader);
81  }
82 
83  LogicalResult writeAttribute(Attribute attr,
84  DialectBytecodeWriter &writer) const override {
85  return ::writeAttribute(attr, writer);
86  }
87 
88  //===--------------------------------------------------------------------===//
89  // Types
90 
91  Type readType(DialectBytecodeReader &reader) const override {
92  return ::readType(getContext(), reader);
93  }
94 
95  LogicalResult writeType(Type type,
96  DialectBytecodeWriter &writer) const override {
97  return ::writeType(type, writer);
98  }
99 
100  void writeVersion(DialectBytecodeWriter &writer) const final {
101  // TODO: Populate.
102  }
103 
104  std::unique_ptr<DialectVersion>
105  readVersion(DialectBytecodeReader &reader) const final {
106  // TODO: Populate
107  reader.emitError("Dialect does not support versioning");
108  return nullptr;
109  }
110 
111  LogicalResult upgradeFromVersion(Operation *topLevelOp,
112  const DialectVersion &version) const final {
113  return success();
114  }
115 };
116 
117 } // namespace
118 
119 //===----------------------------------------------------------------------===//
120 // TOSA control flow support.
121 //===----------------------------------------------------------------------===//
122 
123 /// Returns the while loop body.
124 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
125 
126 //===----------------------------------------------------------------------===//
127 // Tosa dialect initialization.
128 //===----------------------------------------------------------------------===//
129 
130 void TosaDialect::initialize() {
131  addOperations<
132 #define GET_OP_LIST
133 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
134  >();
135  addAttributes<
136 #define GET_ATTRDEF_LIST
137 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
138  >();
139  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
140  declarePromisedInterfaces<
141  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
142  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
143  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
144  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
145  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
146  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
147  GreaterEqualOp, MatMulOp>();
148 }
149 
150 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
151  Type type, Location loc) {
152  // Tosa dialect constants only support ElementsAttr unlike standard dialect
153  // constant which supports all attributes.
154  if (llvm::isa<ElementsAttr>(value))
155  return builder.create<tosa::ConstOp>(loc, type,
156  llvm::cast<ElementsAttr>(value));
157  return nullptr;
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parsers and printers
162 //===----------------------------------------------------------------------===//
163 
164 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
165  Attribute &attr) {
166  if (succeeded(parser.parseOptionalEqual())) {
167  if (failed(parser.parseAttribute(attr))) {
168  return parser.emitError(parser.getCurrentLocation())
169  << "expected attribute";
170  }
171  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
172  typeAttr = TypeAttr::get(typedAttr.getType());
173  }
174  return success();
175  }
176 
177  Type type;
178  if (failed(parser.parseColonType(type))) {
179  return parser.emitError(parser.getCurrentLocation()) << "expected type";
180  }
181  typeAttr = TypeAttr::get(type);
182 
183  return success();
184 }
185 
186 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
187  Attribute attr) {
188  bool needsSpace = false;
189  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
190  if (!typedAttr || typedAttr.getType() != type.getValue()) {
191  p << ": ";
192  p.printAttribute(type);
193  needsSpace = true; // subsequent attr value needs a space separator
194  }
195  if (attr) {
196  if (needsSpace)
197  p << ' ';
198  p << "= ";
199  p.printAttribute(attr);
200  }
201 }
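// Editor's note (illustrative, not part of the upstream source): these two
// helpers are presumably the parser/printer pair for a custom<TypeOrAttr>
// assembly directive. They accept either an explicit initial value or just a
// type, e.g.
//   "= dense<0> : tensor<4xi32>"   -> attribute form, type taken from the attr
//   ": tensor<4xi32>"              -> type-only form, no initial value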
202 
203 //===----------------------------------------------------------------------===//
204 // TOSA Operator Verifiers.
205 //===----------------------------------------------------------------------===//
206 
207 static bool hasZeroDimension(ShapedType shapedType) {
208  if (!shapedType.hasRank())
209  return false;
210 
211  auto rank = shapedType.getRank();
212 
213  for (int i = 0; i < rank; i++) {
214  if (shapedType.isDynamicDim(i))
215  continue;
216  if (shapedType.getDimSize(i) == 0)
217  return true;
218  }
219 
220  return false;
221 }
222 
223 template <typename T>
224 static LogicalResult verifyConvOp(T op) {
225  // All TOSA conv ops have an input() and weight().
226  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
227  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
228 
229  // Must be ranked tensor types
230  if (!inputType) {
231  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
232  return failure();
233  }
234  if (!weightType) {
235  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
236  return failure();
237  }
238 
239  if (hasZeroDimension(inputType))
240  return op.emitOpError() << "tensor has a dimension with size zero. Each "
241  "dimension of a tensor must have size >= 1";
242 
243  auto inputEType = inputType.getElementType();
244  auto weightEType = weightType.getElementType();
245 
246  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
247  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
248 
249  // Either both must be quantized or both unquantized.
250  if (inputIsQuant != weightIsQuant) {
251  op.emitOpError(
252  "expect both input and weight to be float or not together, got ")
253  << inputEType << " and " << weightEType;
254  return failure();
255  }
256 
257  // Quantized types must have a constructed quantization attribute, and
258  // unquantized (float) types must not have one.
259  if ((inputIsQuant && !op.getQuantizationInfo()) ||
260  (!inputIsQuant && op.getQuantizationInfo())) {
261  op.emitOpError("quantizationattr is required for quantized type, and not "
262  "allowed for float type");
263  return failure();
264  }
265 
266  return success();
267 }
268 
269 LogicalResult tosa::ArgMaxOp::verify() {
270  // Ensure output is of 32-bit integer
271  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
272  if (!resultETy.isIntOrIndex())
273  return emitOpError("result tensor is not of integer type");
274 
275  // Ensure axis is within the tensor rank
276  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
277  const int64_t axis = getAxisAttr().getInt();
278  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
279  return emitOpError("specified axis is outside the rank of the tensor");
280 
281  return success();
282 }
283 
284 LogicalResult tosa::AvgPool2dOp::verify() {
285  auto inputType = llvm::cast<ShapedType>(getInput().getType());
286  if (hasZeroDimension(inputType))
287  return emitOpError() << "tensor has a dimension with size zero. Each "
288  "dimension of a tensor must have size >= 1";
289 
290  auto inputETy = inputType.getElementType();
291  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
292 
293  if (auto quantType =
294  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
295  inputETy = quantType.getStorageType();
296 
297  if (auto quantType =
298  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
299  resultETy = quantType.getStorageType();
300 
301  auto accType = getAccType();
302  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
303  return emitOpError("accumulator type for integer tensor is not i32");
304 
305  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
306  return emitOpError("accumulator type for f16 tensor is not f16/f32");
307 
308  if (inputETy.isBF16() && !accType.isF32())
309  return emitOpError("accumulator type for bf16 tensor is not f32");
310 
311  if (inputETy.isF32() && !accType.isF32())
312  return emitOpError("accumulator type for f32 tensor is not f32");
313 
314  if ((inputETy.isF32() && resultETy.isF32()) ||
315  (inputETy.isF16() && resultETy.isF16()) ||
316  (inputETy.isBF16() && resultETy.isBF16()) ||
317  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
318  (inputETy.isInteger(16) && resultETy.isInteger(16)))
319  return success();
320 
321  return emitOpError("input/output element types are incompatible.");
322 }
323 
324 LogicalResult tosa::ClampOp::verify() {
325  mlir::Type inputETy =
326  llvm::cast<ShapedType>(getInput().getType()).getElementType();
327  if (auto quantType =
328  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
329  inputETy = quantType.getStorageType();
330  }
331  mlir::Type maxFpType = getMaxFpAttr().getType();
332  mlir::Type minFpType = getMinFpAttr().getType();
333  mlir::Type outputETy =
334  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
335  if (auto quantType =
336  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
337  outputETy = quantType.getStorageType();
338  }
339  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
340 
341  if (inputETy != outputETy)
342  return emitOpError("input/output element types are incompatible.");
343 
344  // If the input datatype is float, check that the two min/max_fp attributes
345  // share the same type and that their type is either the same as the input's
346  // datatype, or a float type whose bitwidth > input datatype bitwidth.
347  if (!inputETy.isInteger(dataTypeBitWidth)) {
348  if (((maxFpType != minFpType) ||
349  (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
350  inputETy.getIntOrFloatBitWidth())))
351  return emitOpError("min/max attributes types are incompatible with "
352  "input/output element types.");
353  }
354 
355  return success();
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // TOSA Operator Quantization Builders.
360 //===----------------------------------------------------------------------===//
361 
362 /// This builder is called on all convolution operators except TransposeConv,
363 /// which has specialized output shape semantics. The builder also defines the
364 /// bitwidth of the output given the bit width of the input & weight content.
365 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
366  Type outputType, Value input, Value weight,
367  Value bias, DenseI64ArrayAttr pad,
368  DenseI64ArrayAttr stride,
369  DenseI64ArrayAttr dilation) {
370 
371  result.addOperands({input, weight, bias});
372  result.addAttribute("pad", pad);
373  result.addAttribute("stride", stride);
374  result.addAttribute("dilation", dilation);
375 
376  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
377  if (quantAttr) {
378  result.addAttribute("quantization_info", quantAttr);
379  result.addTypes(
380  buildConvOpResultTypeInfo(builder, outputType, input, weight));
381  } else {
382  result.addTypes(outputType);
383  }
384 }
385 
386 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
387 static void buildTransConvOpWithQuantInfo(
388  OpBuilder &builder, OperationState &result, Type outputType, Value input,
389  Value weight, Value bias, DenseI64ArrayAttr outpad,
390  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
391  result.addOperands({input, weight, bias});
392  result.addAttribute("out_pad", outpad);
393  result.addAttribute("stride", stride);
394  result.addAttribute("out_shape", outputShape);
395  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
396 
397  if (quantAttr) {
398  result.addAttribute("quantization_info", quantAttr);
399  result.addTypes(
400  buildConvOpResultTypeInfo(builder, outputType, input, weight));
401  } else {
402  result.addTypes(outputType);
403  }
404 }
405 
406 /// The tosa.fully_connected op has its own builder as it does not have
407 /// strides/dilation/padding.
408 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
409  Type outputType, Value input, Value weight,
410  Value bias) {
411 
412  result.addOperands({input, weight, bias});
413  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
414  if (quantAttr) {
415  result.addAttribute("quantization_info", quantAttr);
416  result.addTypes(
417  buildConvOpResultTypeInfo(builder, outputType, input, weight));
418  } else {
419  result.addTypes(outputType);
420  }
421 }
422 
423 /// The tosa.matmul op is also intended to be generated where a fully_connected
424 /// op would be constructed but the weight is not a constant. In this case,
425 /// the fully_connected op must be expressed using matmul.
426 /// TODO: Add link to the legalization document explaining this.
427 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
428  OperationState &result, Type outputType,
429  Value a, Value b) {
430  result.addOperands({a, b});
431  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
432 
433  if (quantAttr) {
434  result.addAttribute("quantization_info", quantAttr);
435 
436  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
437  assert(inputType && "Input must be a shaped tensor type!");
438 
439  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
440  inputType.getElementType());
441  assert(inputQType && "Tensor must have quantized datatype!");
442 
443  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
444 
445  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
446  assert(outputShapedType && "Output must be a shaped type");
447 
448  IntegerType accElementType;
449  if (inputBits == 16)
450  accElementType = builder.getIntegerType(48);
451  else
452  accElementType = builder.getI32Type();
453  auto accType = outputShapedType.clone(accElementType);
454  result.addTypes(accType);
455  } else {
456  result.addTypes(outputType);
457  }
458 }
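// Editor's note (worked example, derived from the logic above): quantized
// operands with i8 storage produce an i32 accumulator/result element type,
// while i16 storage widens to i48; unquantized operands keep the requested
// outputType unchanged.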
459 
460 /// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr
461 /// but the avg_pool operator has its own builder as it has additional parameters
462 /// not part of the unary ops.
463 static void
464 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
465  Type outputType, Value input,
466  DenseArrayAttr kernel, DenseArrayAttr stride,
467  DenseArrayAttr pad, TypeAttr accType) {
468  result.addOperands(input);
469  result.addAttribute("kernel", kernel);
470  result.addAttribute("stride", stride);
471  result.addAttribute("pad", pad);
472  result.addAttribute("acc_type", accType);
473  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
474  if (quantAttr)
475  result.addAttribute("quantization_info", quantAttr);
476  result.types.push_back(outputType);
477 }
478 
479 /// This builder is called on single-parameter unary operators that have a scale
480 /// relationship between their input and output, expressed by the
481 /// UnaryOpQuantizationAttr.
482 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
483  OperationState &result, Type outputType,
484  Value input) {
485  result.addOperands(input);
486  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
487  if (quantAttr)
488  result.addAttribute("quantization_info", quantAttr);
489  result.types.push_back(outputType);
490 }
491 
492 /// This builder is called on TOSA pad operator that needs to create its own
493 /// OptionalAttr quantization_attr parameter to scale the padding values
494 /// correctly. No pad_const is interpreted as zero-padding.
495 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
496  Type outputType, Value input,
497  Value paddings) {
498  result.addOperands({input, paddings});
499  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
500  if (quantAttr)
501  result.addAttribute("quantization_info", quantAttr);
502  result.types.push_back(outputType);
503 }
504 
505 /// This builder is called on TOSA pad operator when an explicit pad_const
506 /// value is passed in. It also optionally constructs quantization_attr.
507 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
508  OperationState &result,
509  Type outputType, Value input,
510  Value paddings,
511  Value padConst) {
512  result.addOperands({input, paddings, padConst});
513  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
514  if (quantAttr)
515  result.addAttribute("quantization_info", quantAttr);
516  result.types.push_back(outputType);
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // TOSA Operator Return Type Inference.
521 //===----------------------------------------------------------------------===//
522 
523 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
524  SmallVector<int64_t> &outShape) {
525  int64_t outRank = 0;
526  for (int i = 0, e = operands.size(); i != e; ++i) {
527  auto shape = operands.getShape(i);
528  if (!shape.hasRank()) {
529  // TODO(jennik): Update function to have better case handling for invalid
530  // operands and for ranked tensors.
531  return failure();
532  }
533  outRank = std::max<int64_t>(outRank, shape.getRank());
534  }
535 
536  outShape.resize(outRank, 1);
537 
538  for (int i = 0, e = operands.size(); i != e; ++i) {
539  auto shape = operands.getShape(i);
540  auto rankDiff = outShape.size() - shape.getRank();
541 
542  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
543  auto dim1 = outShape[i + rankDiff];
544  auto dim2 = shape.getDimSize(i);
545  auto resolvedDim = dim1;
546 
547  if (dim1 == 1) {
548  resolvedDim = dim2;
549  } else if (dim2 == 1) {
550  resolvedDim = dim1;
551  } else if (dim1 != dim2) {
552  return failure();
553  }
554  outShape[i + rankDiff] = resolvedDim;
555  }
556  }
557 
558  return success();
559 }
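// Editor's worked example (not in the upstream source): broadcasting the
// operand shapes [2, 1, 3] and [4, 3] aligns ranks from the right and
// stretches size-1 dims, giving an output shape of [2, 4, 3]; two static,
// unequal, non-1 sizes make the resolution fail.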
560 
561 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
562  MLIRContext *context, ::std::optional<Location> location,
563  ArgMaxOp::Adaptor adaptor,
564  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
565  ShapeAdaptor inputShape(adaptor.getInput().getType());
566  IntegerAttr axis = adaptor.getProperties().axis;
567  int32_t axisVal = axis.getValue().getSExtValue();
568 
569  if (!inputShape.hasRank()) {
570  inferredReturnShapes.push_back(ShapedTypeComponents());
571  return success();
572  }
573 
574  SmallVector<int64_t> outShape;
575  outShape.reserve(inputShape.getRank() - 1);
576  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
577  if (i == axisVal)
578  continue;
579  outShape.push_back(inputShape.getDimSize(i));
580  }
581 
582  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
583  return success();
584 }
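// Editor's example: for an input of shape [2, 3, 4] with axis = 1, the axis
// dimension is dropped and the inferred result shape is [2, 4].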
585 
586 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
587  MLIRContext *context, ::std::optional<Location> location,
588  RFFT2dOp::Adaptor adaptor,
589  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
590  ShapeAdaptor inputShape(adaptor.getInput().getType());
591 
592  if (!inputShape.hasRank())
593  return failure();
594 
595  llvm::SmallVector<int64_t> outputShape;
596  outputShape.resize(3, ShapedType::kDynamic);
597  outputShape[0] = inputShape.getDimSize(0);
598  outputShape[1] = inputShape.getDimSize(1);
599  int64_t inWidth = inputShape.getDimSize(2);
600 
601  // Note that we can support this calculation symbolically
602  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
603  if (inWidth != ShapedType::kDynamic)
604  outputShape[2] = inWidth / 2 + 1;
605 
606  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
607  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
608 
609  return success();
610 }
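// Editor's example: an input of shape [N, H, W] = [1, 8, 8] produces two
// results (real and imaginary parts) of shape [1, 8, 8 / 2 + 1] = [1, 8, 5].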
611 
612 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
613  MLIRContext *context, ::std::optional<Location> location,
614  FFT2dOp::Adaptor adaptor,
615  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
616  inferredReturnShapes.push_back(
617  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
618  inferredReturnShapes.push_back(
619  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
620  return success();
621 }
622 
623 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
624  MLIRContext *context, ::std::optional<Location> location,
625  ConcatOp::Adaptor adaptor,
626  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
627  // Infer all dimension sizes by reducing based on inputs.
628  const Properties &prop = adaptor.getProperties();
629  int32_t axis = prop.axis.getValue().getSExtValue();
630  llvm::SmallVector<int64_t> outputShape;
631  bool hasRankedInput = false;
632  for (auto operand : adaptor.getOperands()) {
633  ShapeAdaptor operandShape(operand.getType());
634  if (!operandShape.hasRank())
635  continue;
636 
637  // Copy the Operand's rank.
638  if (!hasRankedInput)
639  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
640 
641  // Copy shapes until the dim is non-dynamic.
642  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
643  if (i == axis || operandShape.isDynamicDim(i))
644  continue;
645  if (outputShape[i] == ShapedType::kDynamic)
646  outputShape[i] = operandShape.getDimSize(i);
647  if (outputShape[i] != operandShape.getDimSize(i))
648  return emitOptionalError(location,
649  "Cannot concat tensors with different sizes"
650  " on the non-axis dimension ",
651  i);
652  }
653 
654  hasRankedInput = true;
655  }
656  Type inputType =
657  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
658  if (!hasRankedInput) {
659  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
660  return success();
661  }
662 
663  // Determine the dimension size along the concatenation axis.
664  int64_t concatDimSize = 0;
665  for (auto operand : adaptor.getOperands()) {
666  ShapeAdaptor operandShape(operand.getType());
667 
668  // We need to know the length of the concatenation axis of all inputs to
669  // determine the dimension size of the output shape.
670  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
671  concatDimSize = ShapedType::kDynamic;
672  break;
673  }
674 
675  concatDimSize += operandShape.getDimSize(axis);
676  }
677 
678  outputShape[axis] = concatDimSize;
679 
680  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
681  return success();
682 }
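// Editor's example: concatenating tensor<1x2x?xf32> and tensor<1x?x5xf32> on
// axis = 1 infers tensor<1x?x5xf32>; known non-axis sizes are merged, and the
// axis size stays dynamic because one operand's axis dimension is dynamic.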
683 
684 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
685  MLIRContext *context, ::std::optional<Location> location,
686  ValueShapeRange operands, DictionaryAttr attributes,
687  OpaqueProperties properties, RegionRange regions,
688  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
689  auto elementType = IntegerType::get(context, /*width=*/1);
690 
691  llvm::SmallVector<int64_t> outShape;
692  if (resolveBroadcastShape(operands, outShape).failed()) {
693  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
694  return success();
695  }
696 
697  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
698  return success();
699 }
700 
701 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
702  if (l.size() != r.size() || l.size() != 1)
703  return false;
704  return succeeded(verifyCompatibleShape(l[0], r[0]));
705 }
706 
707 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
708  MLIRContext *context, ::std::optional<Location> location,
709  FullyConnectedOp::Adaptor adaptor,
710  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
711  ShapeAdaptor inputShape(adaptor.getInput().getType());
712  ShapeAdaptor weightShape(adaptor.getWeight().getType());
713  ShapeAdaptor biasShape(adaptor.getBias().getType());
714 
715  // All shapes are dynamic.
716  SmallVector<int64_t> outShape;
717  outShape.resize(2, ShapedType::kDynamic);
718 
719  if (inputShape.hasRank()) {
720  outShape[0] = inputShape.getDimSize(0);
721  }
722 
723  if (weightShape.hasRank()) {
724  outShape[1] = weightShape.getDimSize(0);
725  }
726 
727  if (biasShape.hasRank()) {
728  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
729  : outShape[1];
730  }
731 
732  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
733  return success();
734 }
735 
736 LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
737 
738 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
739  MLIRContext *context, ::std::optional<Location> location,
740  MatMulOp::Adaptor adaptor,
741  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
742  ShapeAdaptor lhsShape(adaptor.getA().getType());
743  ShapeAdaptor rhsShape(adaptor.getB().getType());
744 
745  // All shapes are dynamic.
746  SmallVector<int64_t> outShape;
747  outShape.resize(3, ShapedType::kDynamic);
748 
749  if (lhsShape.hasRank()) {
750  outShape[0] = lhsShape.getDimSize(0);
751  outShape[1] = lhsShape.getDimSize(1);
752  }
753 
754  if (rhsShape.hasRank()) {
755  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
756  : outShape[0];
757  outShape[2] = rhsShape.getDimSize(2);
758  }
759 
760  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
761  return success();
762 }
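// Editor's example: A = tensor<2x3x?xf32>, B = tensor<?x?x5xf32> infers an
// output shape of [2, 3, 5]; the batch dimension comes from whichever operand
// supplies a static size.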
763 
764 LogicalResult tosa::PadOp::inferReturnTypeComponents(
765  MLIRContext *context, ::std::optional<Location> location,
766  PadOp::Adaptor adaptor,
767  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
768  ShapeAdaptor inputShape(adaptor.getInput1().getType());
769  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
770  SmallVector<int64_t> outputShape;
771 
772  // If both inputs have unknown shape, we cannot determine the shape of the
773  // output.
774  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
775  inferredReturnShapes.push_back(ShapedTypeComponents());
776  return success();
777  }
778 
779  // If the input rank is unknown we can infer the output rank using the padding
780  // shape's first dim.
781  if (!inputShape.hasRank()) {
782  if (paddingShape.isDynamicDim(0)) {
783  inferredReturnShapes.push_back(ShapedTypeComponents());
784  return success();
785  }
786 
787  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
788  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
789  return success();
790  }
791 
792  DenseIntElementsAttr paddings;
793  // If the paddings value is not a constant, all dimensions must be dynamic.
794  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
795  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
796  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
797  return success();
798  }
799 
800  SmallVector<int64_t> paddingValues;
801  for (auto val : paddings) {
802  paddingValues.push_back(val.getSExtValue());
803  }
804 
805  outputShape.reserve(inputShape.getRank());
806  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
807  if (inputShape.isDynamicDim(i)) {
808  outputShape.push_back(ShapedType::kDynamic);
809  continue;
810  }
811 
812  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
813  paddingValues[i * 2 + 1]);
814  }
815 
816  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
817  return success();
818 }
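// Editor's example: an input of shape [2, 3] with constant padding
// [[1, 1], [2, 2]] infers [2 + 1 + 1, 3 + 2 + 2] = [4, 7]; if the padding
// operand is not constant, every output dimension is left dynamic.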
819 
820 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
821  return to_vector(llvm::map_range(shape, [](int64_t dim) {
822  return dim == -1 ? ShapedType::kDynamic : dim;
823  }));
824 }
825 
826 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
827  MLIRContext *context, ::std::optional<Location> location,
828  SliceOp::Adaptor adaptor,
829  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
830  inferredReturnShapes.push_back(
831  ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
832  return success();
833 }
834 
835 LogicalResult tosa::SliceOp::verify() {
836  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
837  if (!inputType)
838  return success();
839 
840  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
841  return emitOpError(
842  "length of start attribute is not equal rank of input shape");
843 
844  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
845  return emitOpError(
846  "length of size attribute is not equal rank of input shape");
847 
848  return success();
849 }
850 
851 LogicalResult tosa::TableOp::inferReturnTypeComponents(
852  MLIRContext *context, ::std::optional<Location> location,
853  TableOp::Adaptor adaptor,
854  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
855  ShapeAdaptor inputShape(adaptor.getInput().getType());
856 
857  if (!inputShape.hasRank()) {
858  inferredReturnShapes.push_back(ShapedTypeComponents());
859  return success();
860  }
861 
862  inferredReturnShapes.resize(1);
863  inputShape.getDims(inferredReturnShapes[0]);
864  return success();
865 }
866 
867 LogicalResult tosa::TileOp::inferReturnTypeComponents(
868  MLIRContext *context, ::std::optional<Location> location,
869  TileOp::Adaptor adaptor,
870  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
871  ArrayRef<int64_t> multiples = adaptor.getMultiples();
872  ShapeAdaptor inputShape(adaptor.getInput1().getType());
873  SmallVector<int64_t> outputShape;
874  if (!inputShape.hasRank()) {
875  outputShape.resize(multiples.size(), ShapedType::kDynamic);
876  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
877  return success();
878  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
879  return failure();
880 
881  // Any non-dynamic dimension can be multiplied to a known size.
882  outputShape.reserve(multiples.size());
883  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
884  int64_t dim = inputShape.getDimSize(i);
885  if (dim != ShapedType::kDynamic)
886  dim *= multiples[i];
887  outputShape.push_back(dim);
888  }
889 
890  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
891  return success();
892 }
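// Editor's example: an input of shape [2, ?, 3] with multiples = [3, 2, 1]
// infers [6, ?, 3]; dynamic input dimensions remain dynamic.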
893 
894 LogicalResult tosa::TileOp::verify() {
895  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
896  ShapedType outputType = llvm::cast<ShapedType>(getType());
897  auto multiples = getMultiples();
898 
899  if (inputType.hasRank()) {
900  if (static_cast<size_t>(inputType.getRank()) != multiples.size())
901  return emitOpError("expect 'multiples' array to have length ")
902  << inputType.getRank() << " but got " << multiples.size() << ".";
903  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
904  return emitOpError("expect same input and output tensor rank.");
905  } else if (outputType.hasRank() &&
906  static_cast<size_t>(outputType.getRank()) != multiples.size())
907  return emitOpError("expect 'multiples' array to have length ")
908  << outputType.getRank() << " but got " << multiples.size() << ".";
909 
910  return success();
911 }
912 
913 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
914  if (l.size() != r.size() || l.size() != 1)
915  return false;
916  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
917 }
918 
919 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
920  MLIRContext *context, ::std::optional<Location> location,
921  ReshapeOp::Adaptor adaptor,
922  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
923  ShapeAdaptor inputShape(adaptor.getInput1().getType());
924  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
925  llvm::SmallVector<int64_t> newShapeValue =
926  convertToMlirShape(adaptor.getNewShape());
927 
928  // We cannot infer from the total number of elements so we must take the
929  // shape attribute as exact.
930  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
931  inferredReturnShapes.push_back(
932  ShapedTypeComponents(newShapeValue, inputType));
933  return success();
934  }
935 
936  // Determine the number of elements covered by the slice of all static
937  // dimensions. This allows us to infer the length of the remaining dynamic
938  // dimension.
939  int64_t numElements = inputShape.getNumElements();
940  int64_t staticMul = 1;
941  for (auto val : newShapeValue) {
942  if (!ShapedType::isDynamic(val)) {
943  staticMul *= val;
944  }
945  }
946 
947  // Determine the length of the dynamic dimension.
948  for (auto &val : newShapeValue) {
949  if (ShapedType::isDynamic(val))
950  val = numElements / staticMul;
951  }
952 
953  inferredReturnShapes.push_back(
954  ShapedTypeComponents(newShapeValue, inputType));
955  return success();
956 }
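// Editor's example: a static input tensor<2x3x4xf32> (24 elements) with
// new_shape = [-1, 6] infers [24 / 6, 6] = [4, 6]; for a non-static input the
// -1 entry simply stays dynamic.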
957 
958 llvm::LogicalResult tosa::ReshapeOp::verify() {
959  TensorType inputType = getInput1().getType();
960  RankedTensorType outputType = getType();
961 
962  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
963  return emitOpError() << "tensor has a dimension with size zero. Each "
964  "dimension of a tensor must have size >= 1";
965 
966  if ((int64_t)getNewShape().size() != outputType.getRank())
967  return emitOpError() << "new shape does not match result rank";
968 
969  for (auto [newShapeDim, outputShapeDim] :
970  zip(getNewShape(), outputType.getShape()))
971  if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
972  newShapeDim != outputShapeDim)
973  return emitOpError() << "new shape is inconsistent with result shape";
974 
975  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
976  int64_t inputElementsNum = inputType.getNumElements();
977  int64_t outputElementsNum = outputType.getNumElements();
978  if (inputElementsNum != outputElementsNum) {
979  return emitOpError() << "cannot reshape " << inputElementsNum
980  << " elements into " << outputElementsNum;
981  }
982  }
983 
984  int missingDims = llvm::count(getNewShape(), -1);
985  if (missingDims > 1)
986  return emitOpError() << "expected at most one target dimension to be -1";
987 
988  return mlir::success();
989 }
990 
991 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
992  // Perms must be constants.
993  DenseIntElementsAttr permsAttr;
994  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
995  return failure();
996 
997  // Transpose is not the identity transpose.
998  perms = llvm::to_vector(
999  llvm::map_range(permsAttr.getValues<APInt>(),
1000  [](const APInt &val) { return val.getSExtValue(); }));
1001 
1002  return success();
1003 }
1004 
1005 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1006  MLIRContext *context, ::std::optional<Location> location,
1007  TransposeOp::Adaptor adaptor,
1008  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1009  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1010  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1011 
1012  // We cannot infer anything from a rank-0 "permutation" tensor.
1013  if (permsShape.hasRank() && permsShape.getRank() == 0)
1014  return failure();
1015 
1016  // If the input rank or the permutation length is unknown, the output rank
1017  // is unknown.
1018  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1019  permsShape.isDynamicDim(0)) {
1020  inferredReturnShapes.push_back(ShapedTypeComponents());
1021  return success();
1022  }
1023 
1024  // This would imply the number of permutations does not match the rank of the
1025  // input which is illegal.
1026  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1027  return failure();
1028  }
1029 
1030  SmallVector<int64_t> outputShape;
1031  // Rank-0 means no permutations matter.
1032  if (inputShape.getRank() == 0) {
1033  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1034  return success();
1035  }
1036 
1037  // Check whether the input dimensions are all the same.
1038  bool allTheSame = true;
1039  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1040  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1041  allTheSame = false;
1042  break;
1043  }
1044  }
1045 
1046  // If all of the input dimensions are the same we don't care about the
1047  // permutation.
1048  if (allTheSame) {
1049  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1050  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1051  return success();
1052  }
1053 
1054  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1055  // If the permutations are a constant we can directly determine the output
1056  // shape.
1057  DenseIntElementsAttr attr;
1058  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1059  attr.getType().getRank() == 1) {
1060  ShapeAdaptor permShape = attr;
1061  // Constant permutation must be the same length as the input rank.
1062  if (inputShape.getRank() != permShape.getRank())
1063  return emitOptionalError(location,
1064  "constant permutation must be the same length"
1065  " as the input rank");
1066 
1067  // Constant permutation values must be within the input rank.
1068  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1069  if (inputShape.getRank() <= permShape.getDimSize(i))
1070  return failure();
1071  }
1072 
1073  outputShape.reserve(inputShape.getRank());
1074  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1075  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1076  }
1077  }
1078 
1079  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1080  return success();
1081 }
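// Editor's example: an input of shape [1, 2, 3] with constant perms = [2, 0, 1]
// infers the output shape [3, 1, 2], i.e. output[i] = input[perms[i]].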
1082 
1083 LogicalResult tosa::TransposeOp::verify() {
1084  TensorType inputType = getInput1().getType();
1085  TensorType permType = getPerms().getType();
1086  TensorType outputType = getOutput().getType();
1087 
1088  if (permType.hasRank() && permType.getRank() != 1)
1089  return emitOpError()
1090  << "expected permutation tensor to be rank 1 but got rank "
1091  << permType.getRank();
1092  if (inputType.hasRank() && permType.hasRank())
1093  if (!permType.isDynamicDim(0) &&
1094  permType.getDimSize(0) != inputType.getRank())
1095  return emitOpError() << "expected permutation tensor dim 0 to have size "
1096  << inputType.getRank()
1097  << " (input rank) but got size "
1098  << permType.getDimSize(0);
1099  if (inputType.hasRank() && outputType.hasRank() &&
1100  inputType.getRank() != outputType.getRank())
1101  return emitOpError()
1102  << "expected input tensor rank to equal result tensor rank";
1103  if (outputType.hasRank() && permType.hasRank())
1104  if (!permType.isDynamicDim(0) &&
1105  permType.getDimSize(0) != outputType.getRank())
1106  return emitOpError() << "expected permutation tensor dim 0 to have size "
1107  << outputType.getRank()
1108  << " (output rank) but got size "
1109  << permType.getDimSize(0);
1110 
1111  SmallVector<int64_t> constantPerms;
1112  if (succeeded(getConstantPerms(constantPerms))) {
1113  // Assert that the permutation tensor has a rank, which means that the rank
1114  // has been verified above.
1115  assert(permType.hasRank() &&
1116  "Unexpectedly found permutation tensor without rank");
1117  if (!isPermutationVector(constantPerms))
1118  return emitOpError() << "expected valid permutation tensor";
1119  }
1120  return success();
1121 }
1122 
1123 LogicalResult TransposeOp::reifyResultShapes(
1124  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1125 
1126  SmallVector<int64_t> transposePerms;
1127  if (getConstantPerms(transposePerms).failed())
1128  return failure();
1129 
1130  Value input = getInput1();
1131  auto inputType = cast<TensorType>(input.getType());
1132 
1133  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
1134  for (auto dim : transposePerms) {
1135  int64_t dimInInput = transposePerms[dim];
1136  if (inputType.isDynamicDim(dimInInput))
1137  returnedDims[dim] =
1138  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1139  .getResult();
1140  else
1141  returnedDims[dim] =
1142  builder.getIndexAttr(inputType.getDimSize(dimInInput));
1143  }
1144 
1145  reifiedReturnShapes.emplace_back(std::move(returnedDims));
1146  return success();
1147 }
1148 
1149 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1150  MLIRContext *context, ::std::optional<Location> location,
1151  GatherOp::Adaptor adaptor,
1152  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1153  llvm::SmallVector<int64_t> outputShape;
1154  outputShape.resize(3, ShapedType::kDynamic);
1155 
1156  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1157  if (valuesShape.hasRank()) {
1158  outputShape[0] = valuesShape.getDimSize(0);
1159  outputShape[2] = valuesShape.getDimSize(2);
1160  }
1161 
1162  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1163  if (indicesShape.hasRank()) {
1164  if (outputShape[0] == ShapedType::kDynamic)
1165  outputShape[0] = indicesShape.getDimSize(0);
1166  if (outputShape[1] == ShapedType::kDynamic)
1167  outputShape[1] = indicesShape.getDimSize(1);
1168  }
1169 
1170  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1171  return success();
1172 }
1173 
1174 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1175  MLIRContext *context, ::std::optional<Location> location,
1176  ResizeOp::Adaptor adaptor,
1177  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1178  llvm::SmallVector<int64_t, 4> outputShape;
1179  outputShape.resize(4, ShapedType::kDynamic);
1180 
1181  ShapeAdaptor inputShape(adaptor.getInput().getType());
1182  if (!inputShape.hasRank())
1183  return failure();
1184 
1185  outputShape[0] = inputShape.getDimSize(0);
1186  outputShape[3] = inputShape.getDimSize(3);
1187  int64_t inputHeight = inputShape.getDimSize(1);
1188  int64_t inputWidth = inputShape.getDimSize(2);
1189 
1190  if ((inputHeight == ShapedType::kDynamic) ||
1191  (inputWidth == ShapedType::kDynamic))
1192  return failure();
1193 
1194  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1195  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1196  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1197 
1198  // Compute the output shape based on attributes: scale, offset, and border.
1199  outputShape[1] =
1200  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1201  scaleInt[1]) +
1202  1;
1203 
1204  outputShape[2] =
1205  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1206  scaleInt[3]) +
1207  1;
1208 
1209  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1210  return success();
1211 }
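// Editor's worked example: with input H = W = 4, scale = [2, 1, 2, 1]
// (numerator/denominator per axis), offset = [0, 0] and border = [0, 0], both
// spatial output dims are ((4 - 1) * 2 - 0 + 0) / 1 + 1 = 7.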
1212 
1213 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1214  MLIRContext *context, ::std::optional<Location> location,
1215  ScatterOp::Adaptor adaptor,
1216  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1217  llvm::SmallVector<int64_t> outputShape;
1218  outputShape.resize(3, ShapedType::kDynamic);
1219 
1220  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1221  if (valuesInShape.hasRank()) {
1222  outputShape[0] = valuesInShape.getDimSize(0);
1223  outputShape[1] = valuesInShape.getDimSize(1);
1224  outputShape[2] = valuesInShape.getDimSize(2);
1225  }
1226 
1227  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1228  if (indicesShape.hasRank()) {
1229  if (outputShape[0] == ShapedType::kDynamic)
1230  outputShape[0] = indicesShape.getDimSize(0);
1231  }
1232 
1233  ShapeAdaptor inputShape(adaptor.getInput().getType());
1234  if (inputShape.hasRank()) {
1235  if (outputShape[0] == ShapedType::kDynamic)
1236  outputShape[0] = inputShape.getDimSize(0);
1237  if (outputShape[2] == ShapedType::kDynamic)
1238  outputShape[2] = inputShape.getDimSize(2);
1239  }
1240 
1241  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1242  return success();
1243 }
1244 
1245 static LogicalResult ReduceInferReturnTypes(
1246  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1247  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1248  int64_t axisVal = axis.getValue().getSExtValue();
1249  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1250  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1251  return success();
1252  }
1253 
1254  SmallVector<int64_t> outputShape;
1255  operandShape.getDims(outputShape);
1256  outputShape[axisVal] = 1;
1257  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1258  return success();
1259 }
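// Editor's example: reducing a [2, 3, 4] operand along axis = 1 infers
// [2, 1, 4]; the reduced dimension is kept with size 1.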
1260 
1261 #define COMPATIBLE_RETURN_TYPES(OP) \
1262  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1263  if (l.size() != r.size() || l.size() != 1) \
1264  return false; \
1265  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1266  return false; \
1267  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1268  }
1269 
1270 #define REDUCE_SHAPE_INFER(OP) \
1271  LogicalResult OP::inferReturnTypeComponents( \
1272  MLIRContext *context, ::std::optional<Location> location, \
1273  OP::Adaptor adaptor, \
1274  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1275  Type inputType = \
1276  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1277  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1278  const Properties &prop = adaptor.getProperties(); \
1279  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1280  inferredReturnShapes); \
1281  } \
1282  COMPATIBLE_RETURN_TYPES(OP)
1283 
1284 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1285 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1286 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1287 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1288 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1289 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1290 #undef REDUCE_SHAPE_INFER
1291 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1292 #undef COMPATIBLE_RETURN_TYPES
1293 
1294 template <typename T>
1295 static LogicalResult verifyReduceOp(T op) {
1296  // All TOSA reduce Ops have input, output and axis.
1297  TensorType inputType = op.getInput().getType();
1298  TensorType outputType = op.getOutput().getType();
1299  int32_t reduceAxis = op.getAxis();
1300 
1301  if (reduceAxis < 0) {
1302  op.emitOpError("reduce axis must not be negative");
1303  return failure();
1304  }
1305  if (inputType.hasRank()) {
1306  int64_t inputRank = inputType.getRank();
1307  // We allow for a special case where the input/output shape has rank 0 and
1308  // axis is also 0.
1309  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1310  op.emitOpError("expect input tensor rank (")
1311  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1312  << ")";
1313  return failure();
1314  }
1315  }
1316  if (outputType.hasRank()) {
1317  int64_t outputRank = outputType.getRank();
1318  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1319  op.emitOpError(
1320  "expect output tensor rank to be equal to input tensor rank");
1321  return failure();
1322  }
1323  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1324  op.emitOpError("expect output tensor rank (")
1325  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1326  << ")";
1327  return failure();
1328  }
1329  // We can only verify the reduced dimension size to be 1 if this is not the
1330  // special case of output rank == 0.
1331  if (outputRank != 0) {
1332  auto outputShape = outputType.getShape();
1333  if (!outputType.isDynamicDim(reduceAxis) &&
1334  outputShape[reduceAxis] != 1) {
1335  op.emitOpError("expect reduced dimension size to be 1, got ")
1336  << outputShape[reduceAxis];
1337  return failure();
1338  }
1339  }
1340  }
1341  return success();
1342 }
1343 
1344 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1345 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1346 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1347 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1348 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1349 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1350 
1351 static LogicalResult NAryInferReturnTypes(
1352  const ValueShapeRange &operands,
1353  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1354  llvm::SmallVector<int64_t> outShape;
1355  if (resolveBroadcastShape(operands, outShape).failed()) {
1356  inferredReturnShapes.push_back(ShapedTypeComponents());
1357  } else {
1358  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1359  }
1360  return success();
1361 }
1362 
1363 #define NARY_SHAPE_INFER(OP) \
1364  LogicalResult OP::inferReturnTypeComponents( \
1365  MLIRContext *context, ::std::optional<Location> location, \
1366  ValueShapeRange operands, DictionaryAttr attributes, \
1367  OpaqueProperties properties, RegionRange regions, \
1368  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1369  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1370  }
1371 
1372 NARY_SHAPE_INFER(tosa::AbsOp)
1373 NARY_SHAPE_INFER(tosa::AddOp)
1374 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1375 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1376 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1377 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1378 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1379 NARY_SHAPE_INFER(tosa::CastOp)
1380 NARY_SHAPE_INFER(tosa::CeilOp)
1381 NARY_SHAPE_INFER(tosa::ClampOp)
1382 NARY_SHAPE_INFER(tosa::ClzOp)
1383 NARY_SHAPE_INFER(tosa::CosOp)
1384 NARY_SHAPE_INFER(tosa::ExpOp)
1385 NARY_SHAPE_INFER(tosa::FloorOp)
1386 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1387 NARY_SHAPE_INFER(tosa::GreaterOp)
1388 NARY_SHAPE_INFER(tosa::IdentityOp)
1389 NARY_SHAPE_INFER(tosa::IntDivOp)
1390 NARY_SHAPE_INFER(tosa::LogOp)
1391 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1392 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1393 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1394 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1395 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1396 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1397 NARY_SHAPE_INFER(tosa::MaximumOp)
1398 NARY_SHAPE_INFER(tosa::MinimumOp)
1399 NARY_SHAPE_INFER(tosa::MulOp)
1400 NARY_SHAPE_INFER(tosa::NegateOp)
1401 NARY_SHAPE_INFER(tosa::PowOp)
1402 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1403 NARY_SHAPE_INFER(tosa::RescaleOp)
1404 NARY_SHAPE_INFER(tosa::ReverseOp)
1405 NARY_SHAPE_INFER(tosa::RsqrtOp)
1406 NARY_SHAPE_INFER(tosa::SinOp)
1407 NARY_SHAPE_INFER(tosa::SelectOp)
1408 NARY_SHAPE_INFER(tosa::SubOp)
1409 NARY_SHAPE_INFER(tosa::TanhOp)
1410 NARY_SHAPE_INFER(tosa::ErfOp)
1411 NARY_SHAPE_INFER(tosa::SigmoidOp)
1412 #undef NARY_SHAPE_INFER
1413 
1414 static LogicalResult poolingInferReturnTypes(
1415  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1416  ArrayRef<int64_t> pad,
1417  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1418  llvm::SmallVector<int64_t> outputShape;
1419  outputShape.resize(4, ShapedType::kDynamic);
1420 
1421  // If the input type is unranked, we only know the output rank, not the sizes.
1422  if (!inputShape) {
1423  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1424  return success();
1425  }
1426 
1427  // Batch and number of channels are identical for pooling layer.
1428  outputShape[0] = inputShape.getDimSize(0);
1429  outputShape[3] = inputShape.getDimSize(3);
1430 
1431  int64_t height = inputShape.getDimSize(1);
1432  int64_t width = inputShape.getDimSize(2);
1433 
1434  if (!ShapedType::isDynamic(height)) {
1435  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1436  outputShape[1] = padded / stride[0] + 1;
1437  }
1438 
1439  if (!ShapedType::isDynamic(width)) {
1440  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1441  outputShape[2] = padded / stride[1] + 1;
1442  }
1443 
1444  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1445  return success();
1446 }
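// Editor's worked example: an NHWC input of [1, 32, 32, 8] with kernel = [2, 2],
// stride = [2, 2] and pad = [0, 0, 0, 0] gives (32 + 0 + 0 - 2) / 2 + 1 = 16,
// i.e. an inferred output shape of [1, 16, 16, 8].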
1447 
1448 LogicalResult Conv2DOp::inferReturnTypeComponents(
1449  MLIRContext *context, ::std::optional<Location> location,
1450  Conv2DOp::Adaptor adaptor,
1451  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1452  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1453 
1454  int64_t inputWidth = ShapedType::kDynamic;
1455  int64_t inputHeight = ShapedType::kDynamic;
1456  int64_t weightWidth = ShapedType::kDynamic;
1457  int64_t weightHeight = ShapedType::kDynamic;
1458 
1459  // Input shape describes input width/height and batch.
1460 
1461  ShapeAdaptor inputShape(adaptor.getInput().getType());
1462  if (inputShape.hasRank()) {
1463  outputShape[0] = inputShape.getDimSize(0);
1464  inputHeight = inputShape.getDimSize(1);
1465  inputWidth = inputShape.getDimSize(2);
1466  }
1467 
1468  // Weight shape describes the filter width/height and the output channels.
1469  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1470  if (weightShape.hasRank()) {
1471  outputShape[3] = weightShape.getDimSize(0);
1472  weightHeight = weightShape.getDimSize(1);
1473  weightWidth = weightShape.getDimSize(2);
1474  }
1475 
1476  // Bias shape can describe the output channels.
1477  ShapeAdaptor biasShape(adaptor.getBias().getType());
1478  if (biasShape.hasRank()) {
1479  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1480  ? biasShape.getDimSize(0)
1481  : outputShape[3];
1482  }
1483 
1484  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1485  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1486  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1487 
1488  if (!ShapedType::isDynamic(inputHeight) &&
1489  !ShapedType::isDynamic(weightHeight)) {
1490  int64_t inputSize = inputHeight + padding[0] + padding[1];
1491  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1492  int64_t unstridedResult = inputSize - filterSize + 1;
1493  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1494  }
1495 
1496  if (!ShapedType::isDynamic(inputWidth) &&
1497  !ShapedType::isDynamic(weightWidth)) {
1498  int64_t inputSize = inputWidth + padding[2] + padding[3];
1499  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1500  int64_t unstridedResult = inputSize - filterSize + 1;
1501  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1502  }
1503 
1504  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1505  return success();
1506 }
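// Editor's worked example: input [1, 16, 16, 4], weight [8, 3, 3, 4],
// pad = [1, 1, 1, 1], stride = [1, 1], dilation = [1, 1]:
//   filterSize = (3 - 1) * 1 + 1 = 3, inputSize = 16 + 1 + 1 = 18,
//   output dim = (18 - 3 + 1 - 1) / 1 + 1 = 16  =>  output [1, 16, 16, 8].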
1507 
1508 LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1509 
1510 LogicalResult Conv3DOp::inferReturnTypeComponents(
1511  MLIRContext *context, ::std::optional<Location> location,
1512  Conv3DOp::Adaptor adaptor,
1513  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1514  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1515 
1516  int64_t inputWidth = ShapedType::kDynamic;
1517  int64_t inputHeight = ShapedType::kDynamic;
1518  int64_t inputDepth = ShapedType::kDynamic;
1519 
1520  int64_t weightWidth = ShapedType::kDynamic;
1521  int64_t weightHeight = ShapedType::kDynamic;
1522  int64_t weightDepth = ShapedType::kDynamic;
1523 
1524  // Input shape describes input depth/height/width and batch.
1525  ShapeAdaptor inputShape(adaptor.getInput().getType());
1526  if (inputShape.hasRank()) {
1527  outputShape[0] = inputShape.getDimSize(0);
1528  inputDepth = inputShape.getDimSize(1);
1529  inputHeight = inputShape.getDimSize(2);
1530  inputWidth = inputShape.getDimSize(3);
1531  }
1532 
1533  // The weight shape describes the filter depth/height/width and the output channels.
1534  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1535  if (weightShape.hasRank()) {
1536  outputShape[4] = weightShape.getDimSize(0);
1537  weightDepth = weightShape.getDimSize(1);
1538  weightHeight = weightShape.getDimSize(2);
1539  weightWidth = weightShape.getDimSize(3);
1540  }
1541 
1542  // Bias shape can describe the output channels.
1543  ShapeAdaptor biasShape(adaptor.getBias().getType());
1544  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1545  outputShape[4] = biasShape.getDimSize(0);
1546  }
1547 
1548  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1549  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1550  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1551 
1552  if (!ShapedType::isDynamic(inputDepth) &&
1553  !ShapedType::isDynamic(weightDepth)) {
1554  int64_t inputSize = inputDepth + pad[0] + pad[1];
1555  int64_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1556  int64_t unstridedResult = inputSize - filterSize + 1;
1557  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1558  }
1559 
1560  if (!ShapedType::isDynamic(inputHeight) &&
1561  !ShapedType::isDynamic(weightHeight)) {
1562  int64_t inputSize = inputHeight + pad[2] + pad[3];
1563  int64_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1564  int64_t unstridedResult = inputSize - filterSize + 1;
1565  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1566  }
1567 
1568  if (!ShapedType::isDynamic(inputWidth) &&
1569  !ShapedType::isDynamic(weightWidth)) {
1570  int64_t inputSize = inputWidth + pad[4] + pad[5];
1571  int64_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1572  int64_t unstridedResult = inputSize - filterSize + 1;
1573  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1574  }
1575 
1576  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1577  return success();
1578 }
1579 
1580 LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1581 
1582 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1583  MLIRContext *context, ::std::optional<Location> location,
1584  AvgPool2dOp::Adaptor adaptor,
1585  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1586  ShapeAdaptor inputShape(adaptor.getInput().getType());
1587  const Properties &prop = adaptor.getProperties();
1588  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1589  inferredReturnShapes);
1590 }
1591 
1592 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1593  MLIRContext *context, ::std::optional<Location> location,
1594  MaxPool2dOp::Adaptor adaptor,
1595  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1596  ShapeAdaptor inputShape(adaptor.getInput().getType());
1597  const Properties &prop = adaptor.getProperties();
1598  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1599  inferredReturnShapes);
1600 }
1601 
1602 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1603  MLIRContext *context, ::std::optional<Location> location,
1604  DepthwiseConv2DOp::Adaptor adaptor,
1605  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1606  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1607 
1608  int64_t inputWidth = ShapedType::kDynamic;
1609  int64_t inputHeight = ShapedType::kDynamic;
1610  int64_t inputChannels = ShapedType::kDynamic;
1611 
1612  int64_t weightWidth = ShapedType::kDynamic;
1613  int64_t weightHeight = ShapedType::kDynamic;
1614  int64_t depthChannels = ShapedType::kDynamic;
1615 
1616  // Input shape describes input height/width, channels, and batch.
1617  ShapeAdaptor inputShape(adaptor.getInput().getType());
1618  if (inputShape.hasRank()) {
1619  outputShape[0] = inputShape.getDimSize(0);
1620  inputHeight = inputShape.getDimSize(1);
1621  inputWidth = inputShape.getDimSize(2);
1622  inputChannels = inputShape.getDimSize(3);
1623  }
1624 
1625  // The weight shape describes the filter height/width, the input channels, and the depth multiplier.
1626  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1627  if (weightShape.hasRank()) {
1628  weightHeight = weightShape.getDimSize(0);
1629  weightWidth = weightShape.getDimSize(1);
1630  inputChannels = ShapedType::isDynamic(inputChannels)
1631  ? weightShape.getDimSize(2)
1632  : inputChannels;
1633  depthChannels = weightShape.getDimSize(3);
1634  }
1635 
1636  // If both inputChannels and depthChannels are available, we can determine
1637  // the output channels.
1638  if (!ShapedType::isDynamic(inputChannels) &&
1639  !ShapedType::isDynamic(depthChannels)) {
1640  outputShape[3] = inputChannels * depthChannels;
1641  }
1642 
1643  // Bias shape can describe the output channels.
1644  ShapeAdaptor biasShape(adaptor.getBias().getType());
1645  if (biasShape.hasRank()) {
1646  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1647  ? biasShape.getDimSize(0)
1648  : outputShape[3];
1649  }
1650 
1651  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1652  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1653  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1654 
1655  if (!ShapedType::isDynamic(inputHeight) &&
1656  !ShapedType::isDynamic(weightHeight)) {
1657  int64_t inputSize = inputHeight + padding[0] + padding[1];
1658  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1659  int64_t unstridedResult = inputSize - filterSize + 1;
1660  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1661  }
1662 
1663  if (!ShapedType::isDynamic(inputWidth) &&
1664  !ShapedType::isDynamic(weightWidth)) {
1665  int64_t inputSize = inputWidth + padding[2] + padding[3];
1666  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1667  int64_t unstridedResult = inputSize - filterSize + 1;
1668  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1669  }
1670 
1671  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1672  return success();
1673 }
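// As an illustration of the channel computation above (hypothetical
// shapes): a weight of shape [3, 3, 8, 2] (KH, KW, C, M) together with an
// 8-channel input yields outputShape[3] = 8 * 2 = 16 output channels; the
// spatial dimensions follow the same padded/dilated/strided arithmetic as
// Conv2DOp.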
1674 
1675 LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
1676 
1677 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1678  MLIRContext *context, ::std::optional<Location> location,
1679  TransposeConv2DOp::Adaptor adaptor,
1680  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1681  // outputShape is mutable.
1682  llvm::SmallVector<int64_t> outputShape =
1683  convertToMlirShape(adaptor.getOutShape());
1684 
1685  int64_t inputWidth = ShapedType::kDynamic;
1686  int64_t inputHeight = ShapedType::kDynamic;
1687  int64_t weightWidth = ShapedType::kDynamic;
1688  int64_t weightHeight = ShapedType::kDynamic;
1689 
1690  // Input shape describes input width/height and batch.
1691  ShapeAdaptor inputShape(adaptor.getInput().getType());
1692  if (inputShape.hasRank()) {
1693  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1694  ? inputShape.getDimSize(0)
1695  : outputShape[0];
1696  inputHeight = inputShape.getDimSize(1);
1697  inputWidth = inputShape.getDimSize(2);
1698  }
1699 
1700  // The weight shape describes the filter height/width and the output channels.
1701  ShapeAdaptor weightShape(adaptor.getFilter().getType());
1702  if (weightShape.hasRank()) {
1703  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1704  ? weightShape.getDimSize(0)
1705  : outputShape[3];
1706  weightHeight = weightShape.getDimSize(1);
1707  weightWidth = weightShape.getDimSize(2);
1708  }
1709 
1710  // Bias shape can describe the output channels.
1711  ShapeAdaptor biasShape(adaptor.getBias().getType());
1712  if (biasShape.hasRank()) {
1713  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1714  ? biasShape.getDimSize(0)
1715  : outputShape[3];
1716  }
1717 
1718  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
1719  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1720 
1721  if (!ShapedType::isDynamic(inputHeight) &&
1722  !ShapedType::isDynamic(weightHeight)) {
1723  int64_t calculateSize =
1724  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
1725  outputShape[1] =
1726  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
1727  }
1728 
1729  if (!ShapedType::isDynamic(inputWidth) &&
1730  !ShapedType::isDynamic(weightWidth)) {
1731  int64_t calculateSize =
1732  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
1733  outputShape[2] =
1734  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
1735  }
1736 
1737  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1738  return success();
1739 }
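// A worked instance of the size expression above (illustrative numbers
// only): inputHeight = 16, stride = {2, 2}, out_pad = {0, 0, 0, 0} and
// weightHeight = 2 gives (16 - 1) * 2 + 0 + 0 + 2 = 32 for outputShape[1]
// when no static out_shape entry overrides it.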
1740 
1741 LogicalResult IfOp::inferReturnTypeComponents(
1742  MLIRContext *context, ::std::optional<Location> location,
1743  IfOp::Adaptor adaptor,
1744  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1745  llvm::SmallVector<tosa::YieldOp> yieldOps;
1746  for (Region *region : adaptor.getRegions()) {
1747  for (auto &block : *region)
1748  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1749  yieldOps.push_back(returnOp);
1750  }
1751 
1752  if (yieldOps.empty())
1753  return failure();
1754 
1755  // Get the initial type information for the yield op.
1756  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1757  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1758  for (auto operand : yieldOps.front().getOperands()) {
1759  resultKnowledge.push_back(
1760  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1761  }
1762 
1763  for (auto yieldOp : yieldOps) {
1764  if (resultKnowledge.size() != yieldOp.getNumOperands())
1765  return failure();
1766 
1767  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1768  int32_t index = it.index();
1769  auto meet = ValueKnowledge::meet(
1770  resultKnowledge[index],
1771  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1772  if (!meet)
1773  continue;
1774  resultKnowledge[index] = meet;
1775  }
1776  }
1777 
1778  for (const ValueKnowledge &result : resultKnowledge) {
1779  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1780  }
1781 
1782  return success();
1783 }
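// As an illustration (hypothetical types): if one region yields a
// tensor<1x?xf32> and the other a tensor<1x4xf32>, the loop above visits
// each operand position and refines resultKnowledge[0] through
// ValueKnowledge::meet; when a meet fails, the knowledge already gathered
// for that position is kept as-is (the `continue` above).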
1784 
1785 LogicalResult WhileOp::inferReturnTypeComponents(
1786  MLIRContext *context, ::std::optional<Location> location,
1787  WhileOp::Adaptor adaptor,
1788  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1789  llvm::SmallVector<tosa::YieldOp> yieldOps;
1790  for (auto &block : adaptor.getBody())
1791  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1792  yieldOps.push_back(returnOp);
1793 
1794  // TOSA's while must have a tosa.yield as its terminator. If one is not
1795  // found, this tosa.while is invalid.
1796  if (yieldOps.empty())
1797  return failure();
1798 
1799  // Get the initial type information from the operand types.
1800  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1801  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1802  for (auto operand : yieldOps.front().getOperands()) {
1803  resultKnowledge.push_back(
1804  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1805  }
1806 
1807  for (auto yieldOp : yieldOps) {
1808  if (resultKnowledge.size() != yieldOp.getNumOperands())
1809  return failure();
1810 
1811  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1812  int32_t index = it.index();
1813  if (auto meet = ValueKnowledge::meet(
1814  resultKnowledge[index],
1815  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1816  resultKnowledge[index] = meet;
1817  }
1818  }
1819  }
1820 
1821  for (const ValueKnowledge &result : resultKnowledge) {
1822  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1823  }
1824 
1825  return success();
1826 }
1827 
1828 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
1829  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
1830  return llvm::to_vector<4>(vt.getShape());
1831  return std::nullopt;
1832 }
1833 
1834 // The parse and print methods of IfOp follow the implementation of the SCF dialect.
1835 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
1836  // Create the regions for the 'then' and 'else' branches.
1837  result.regions.reserve(2);
1838  Region *thenRegion = result.addRegion();
1839  Region *elseRegion = result.addRegion();
1840 
1841  auto &builder = parser.getBuilder();
1842  OpAsmParser::UnresolvedOperand cond;
1843  // Create an i1 tensor type for the boolean condition.
1844  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1845  if (parser.parseOperand(cond) ||
1846  parser.resolveOperand(cond, i1Type, result.operands))
1847  return failure();
1848  // Parse optional results type list.
1849  if (parser.parseOptionalArrowTypeList(result.types))
1850  return failure();
1851  // Parse the 'then' region.
1852  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1853  return failure();
1854 
1855  // If we find an 'else' keyword then parse the 'else' region.
1856  if (!parser.parseOptionalKeyword("else")) {
1857  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1858  return failure();
1859  }
1860 
1861  // Parse the optional attribute list.
1862  if (parser.parseOptionalAttrDict(result.attributes))
1863  return failure();
1864  return success();
1865 }
1866 
1867 void IfOp::print(OpAsmPrinter &p) {
1868  bool printBlockTerminators = false;
1869 
1870  p << " " << getCond();
1871  if (!getResults().empty()) {
1872  p << " -> (" << getResultTypes() << ")";
1873  // Print yield explicitly if the op defines values.
1874  printBlockTerminators = true;
1875  }
1876  p << ' ';
1877  p.printRegion(getThenBranch(),
1878  /*printEntryBlockArgs=*/false,
1879  /*printBlockTerminators=*/printBlockTerminators);
1880 
1881  // Print the 'else' region if it exists and has a block.
1882  auto &elseRegion = getElseBranch();
1883  if (!elseRegion.empty()) {
1884  p << " else ";
1885  p.printRegion(elseRegion,
1886  /*printEntryBlockArgs=*/false,
1887  /*printBlockTerminators=*/printBlockTerminators);
1888  }
1889 
1890  p.printOptionalAttrDict((*this)->getAttrs());
1891 }
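// For reference, the custom syntax handled by the parse/print pair above
// looks roughly like the following (a hypothetical example; operand and
// result types are chosen only for illustration):
//
//   %r = tosa.cond_if %cond -> (tensor<4xf32>) {
//     tosa.yield %a : tensor<4xf32>
//   } else {
//     tosa.yield %b : tensor<4xf32>
//   }
//
// where %cond resolves to the rank-0 i1 tensor type built in parse().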
1892 
1893 LogicalResult ReverseOp::verify() {
1894  TensorType inputType = getInput().getType();
1895  TensorType outputType = getOutput().getType();
1896  int32_t reverseAxis = getAxis();
1897 
1898  if (reverseAxis < 0)
1899  return emitOpError("expected non-negative reverse axis");
1900  if (inputType.hasRank()) {
1901  int64_t inputRank = inputType.getRank();
1902  // We allow for a special case where the input/output shape has rank 0 and
1903  // axis is also 0.
1904  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1905  return emitOpError("expect input tensor rank (")
1906  << inputRank << ") to be larger than reverse axis (" << reverseAxis
1907  << ")";
1908  }
1909  if (outputType.hasRank()) {
1910  int64_t outputRank = outputType.getRank();
1911  if (inputType.hasRank() && outputRank != inputType.getRank())
1912  return emitOpError(
1913  "expect output tensor rank to be equal to input tensor rank");
1914  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
1915  return emitOpError("expect output tensor rank (")
1916  << outputRank << ") to be larger than reverse axis ("
1917  << reverseAxis << ")";
1918  }
1919  return success();
1920 }
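// For instance (hypothetical shapes): reversing a tensor<2x3xf32> along
// axis 1 passes this verifier, axis 2 is rejected with the "larger than
// reverse axis" error above, and a rank-0 input with axis 0 is the one
// permitted exception.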
1921 
1922 // The parse and print methods of WhileOp follow the implementation of the SCF dialect.
1923 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
1924  SmallVector<OpAsmParser::Argument, 4> regionArgs;
1925  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1926  Region *cond = result.addRegion();
1927  Region *body = result.addRegion();
1928 
1929  OptionalParseResult listResult =
1930  parser.parseOptionalAssignmentList(regionArgs, operands);
1931  if (listResult.has_value() && failed(listResult.value()))
1932  return failure();
1933 
1934  FunctionType functionType;
1935  SMLoc typeLoc = parser.getCurrentLocation();
1936  if (failed(parser.parseColonType(functionType)))
1937  return failure();
1938 
1939  result.addTypes(functionType.getResults());
1940 
1941  if (functionType.getNumInputs() != operands.size()) {
1942  return parser.emitError(typeLoc)
1943  << "expected as many input types as operands "
1944  << "(expected " << operands.size() << " got "
1945  << functionType.getNumInputs() << ")";
1946  }
1947 
1948  // Resolve input operands.
1949  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1950  parser.getCurrentLocation(),
1951  result.operands)))
1952  return failure();
1953 
1954  // Propagate the types into the region arguments.
1955  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1956  regionArgs[i].type = functionType.getInput(i);
1957 
1958  return failure(parser.parseRegion(*cond, regionArgs) ||
1959  parser.parseKeyword("do") || parser.parseRegion(*body) ||
1960  parser.parseOptionalAttrDictWithKeyword(result.attributes));
1961 }
1962 
1963 static void printInitializationList(OpAsmPrinter &parser,
1964  Block::BlockArgListType blocksArgs,
1965  ValueRange initializers,
1966  StringRef prefix = "") {
1967  assert(blocksArgs.size() == initializers.size() &&
1968  "expected same length of arguments and initializers");
1969  if (initializers.empty())
1970  return;
1971 
1972  parser << prefix << '(';
1973  llvm::interleaveComma(
1974  llvm::zip(blocksArgs, initializers), parser,
1975  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
1976  parser << ")";
1977 }
1978 
1979 void WhileOp::print(OpAsmPrinter &parser) {
1980  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
1981  " ");
1982  parser << " : ";
1983  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
1984  parser << ' ';
1985  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
1986  parser << " do ";
1987  parser.printRegion(getBody());
1988  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1989 }
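// For reference, the custom syntax handled by the parse/print pair above
// looks roughly like the following (a hypothetical example; the loop-carried
// value and its type are chosen only for illustration):
//
//   %r = tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...   // condition region, terminated by a tosa.yield of an i1 tensor
//   } do {
//     ...   // body region, terminated by a tosa.yield of the next value
//   }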
1990 
1991 //===----------------------------------------------------------------------===//
1992 // TOSA Attribute Definitions.
1993 //===----------------------------------------------------------------------===//
1994 
1995 #define GET_ATTRDEF_CLASSES
1996 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
1997 
1998 //===----------------------------------------------------------------------===//
1999 // TOSA Operator Definitions.
2000 //===----------------------------------------------------------------------===//
2001 
2002 #define GET_OP_CLASSES
2003 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"