TosaOps.cpp
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
40 
41 //===----------------------------------------------------------------------===//
42 // Tosa dialect interface includes.
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
46 
47 namespace {
48 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
49 
50 //===----------------------------------------------------------------------===//
51 // Dialect Function Inliner Interface.
52 //===----------------------------------------------------------------------===//
53 struct TosaInlinerInterface : public DialectInlinerInterface {
54  using DialectInlinerInterface::DialectInlinerInterface;
55 
56  //===--------------------------------------------------------------------===//
57  // Analysis Hooks.
58  //===--------------------------------------------------------------------===//
59 
60  /// All operations can be inlined by default.
61  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
62  IRMapping &map) const final {
63  return true;
64  }
65 
66  /// All regions with If and While parent operators can be inlined.
67  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68  IRMapping &map) const final {
69  return (isa<tosa::IfOp>(dest->getParentOp()) ||
70  isa<tosa::WhileOp>(dest->getParentOp()));
71  }
72 };
73 
74 /// This class implements the bytecode interface for the Tosa dialect.
75 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
76  TosaDialectBytecodeInterface(Dialect *dialect)
77  : BytecodeDialectInterface(dialect) {}
78 
79  //===--------------------------------------------------------------------===//
80  // Attributes
81 
82  Attribute readAttribute(DialectBytecodeReader &reader) const override {
83  return ::readAttribute(getContext(), reader);
84  }
85 
86  LogicalResult writeAttribute(Attribute attr,
87  DialectBytecodeWriter &writer) const override {
88  return ::writeAttribute(attr, writer);
89  }
90 
91  //===--------------------------------------------------------------------===//
92  // Types
93 
94  Type readType(DialectBytecodeReader &reader) const override {
95  return ::readType(getContext(), reader);
96  }
97 
98  LogicalResult writeType(Type type,
99  DialectBytecodeWriter &writer) const override {
100  return ::writeType(type, writer);
101  }
102 
103  void writeVersion(DialectBytecodeWriter &writer) const final {
104  // TODO: Populate.
105  }
106 
107  std::unique_ptr<DialectVersion>
108  readVersion(DialectBytecodeReader &reader) const final {
109  // TODO: Populate
110  reader.emitError("Dialect does not support versioning");
111  return nullptr;
112  }
113 
114  LogicalResult upgradeFromVersion(Operation *topLevelOp,
115  const DialectVersion &version) const final {
116  return success();
117  }
118 };
119 
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // TOSA control flow support.
124 //===----------------------------------------------------------------------===//
125 
126 /// Returns the while loop body.
127 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
128 
129 //===----------------------------------------------------------------------===//
130 // Tosa dialect initialization.
131 //===----------------------------------------------------------------------===//
132 
133 void TosaDialect::initialize() {
134  addTypes<
135 #define GET_TYPEDEF_LIST
136 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
137  >();
138  addOperations<
139 #define GET_OP_LIST
140 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
141  >();
142  addAttributes<
143 #define GET_ATTRDEF_LIST
144 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
145  >();
146  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
147  declarePromisedInterfaces<
148  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
149  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
150  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
151  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
152  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
153  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
154  GreaterEqualOp, MatMulOp>();
155 }
156 
157 Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
158  Type type, Location loc) {
159  // Tosa dialect constants only support ElementsAttr unlike standard dialect
160  // constant which supports all attributes.
161  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
162  return builder.create<tosa::ConstShapeOp>(
163  loc, type, llvm::cast<DenseIntElementsAttr>(value));
164  }
165  if (llvm::isa<ElementsAttr>(value))
166  return builder.create<tosa::ConstOp>(loc, type,
167  llvm::cast<ElementsAttr>(value));
168  return nullptr;
169 }
170 
171 //===----------------------------------------------------------------------===//
172 // Parsers and printers
173 //===----------------------------------------------------------------------===//
174 
175 ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
176  Attribute &attr) {
177  if (succeeded(parser.parseOptionalEqual())) {
178  if (failed(parser.parseAttribute(attr))) {
179  return parser.emitError(parser.getCurrentLocation())
180  << "expected attribute";
181  }
182  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
183  typeAttr = TypeAttr::get(typedAttr.getType());
184  }
185  return success();
186  }
187 
188  Type type;
189  if (failed(parser.parseColonType(type))) {
190  return parser.emitError(parser.getCurrentLocation()) << "expected type";
191  }
192  typeAttr = TypeAttr::get(type);
193 
194  return success();
195 }
196 
197 void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
198  Attribute attr) {
199  bool needsSpace = false;
200  auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
201  if (!typedAttr || typedAttr.getType() != type.getValue()) {
202  p << ": ";
203  p.printAttribute(type);
204  needsSpace = true; // subsequent attr value needs a space separator
205  }
206  if (attr) {
207  if (needsSpace)
208  p << ' ';
209  p << "= ";
210  p.printAttribute(attr);
211  }
212 }
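// Illustrative note (not part of the upstream source): parseTypeOrAttr and
// printTypeOrAttr handle the two assembly forms used here, roughly either an
// explicit attribute value such as `= dense<0> : tensor<1xi32>` (the type is
// then taken from the typed attribute), or a bare type such as
// `: tensor<1xi32>` when no attribute value is present.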
213 
214 //===----------------------------------------------------------------------===//
215 // TOSA Operator Verifiers.
216 //===----------------------------------------------------------------------===//
217 
218 template <typename T>
219 static LogicalResult verifyConvOp(T op) {
220  // All TOSA conv ops have an input and weight arguments which must be ranked
221  // tensors.
222  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
223  if (!inputType) {
224  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
225  return failure();
226  }
227 
228  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
229  if (!weightType) {
230  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
231  return failure();
232  }
233 
234  auto inputEType = inputType.getElementType();
235  auto weightEType = weightType.getElementType();
236  auto biasEType =
237  llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
238  auto resultEType =
239  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
240  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
241  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
242 
243  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
244  inputEType = quantType.getStorageType();
245 
246  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
247  biasEType = quantType.getStorageType();
248 
249  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
250  resultEType = quantType.getStorageType();
251 
252  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
253  // for now, only enforce bias element type == result element type for
254  // float types.
255  op.emitOpError(
256  "expect both bias and result to have same element type, got ")
257  << biasEType << " and " << resultEType;
258  return failure();
259  }
260 
261  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
262  isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
263  if (inputEType != weightEType) {
264  op.emitOpError(
265  "expect both input and weight to have same element type, got ")
266  << inputEType << " and " << weightEType;
267  return failure();
268  }
269  }
270 
271  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
272  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
273 
274  // Either both must be float or both non-float.
275  if (inputIsFloat != weightIsFloat) {
276  op.emitOpError(
277  "expect both input and weight to be float or not together, got ")
278  << inputEType << " and " << weightEType;
279  return failure();
280  }
281 
282  // We require an explicit input zero point and weight zero point for i8
283  // convolution.
284  if (!op.getInputZp() && !op.getWeightZp())
285  return inputEType.isInteger(8) ? failure() : success();
286 
287  ElementsAttr inputZpAttr;
288  ElementsAttr weightZpAttr;
289  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
290  !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
291  op.emitOpError(
292  "bail out if the actual value of zero points cannot be determined");
293  return failure();
294  }
295 
296  // Get and verify explicit zero points.
297  int64_t inputZpVal;
298  int64_t weightZpVal;
299 
300  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
301  tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
302  .failed()) {
303  op.emitOpError("input zero point must be zero for non-int8 integer types");
304  return failure();
305  }
306 
307  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
308  tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
309  .failed()) {
310  op.emitOpError("weight zero point must be zero for non-int8 integer types");
311  return failure();
312  }
313 
314  return success();
315 }
316 
317 LogicalResult tosa::ConstOp::verify() {
318 
319  auto attrType = llvm::dyn_cast<TensorType>(getValueAttr().getType());
320  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
321 
322  if (!attrType || !outputType) {
323  emitOpError("expected tensors for attr/result type");
324  return failure();
325  }
326 
327  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
328  outputType.getElementType())) {
329  if (result.getStorageType() == attrType.getElementType())
330  return success();
331  }
332 
333  if (attrType.getElementType() != outputType.getElementType()) {
334  emitOpError("expected same attr/result element types");
335  return failure();
336  }
337 
338  return success();
339 }
340 
341 template <typename T>
342 static LogicalResult verifyConvOpModes(T op) {
343  auto inputEType =
344  llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
345 
346  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
347  inputEType = quantType.getStorageType();
348 
349  auto accType = op.getAccType();
350  if (inputEType.isInteger(8) && !accType.isInteger(32))
351  return op.emitOpError("accumulator type for i8 tensor is not i32");
352 
353  if (inputEType.isInteger(16) && !accType.isInteger(48))
354  return op.emitOpError("accumulator type for i16 tensor is not i48");
355 
356  if (isa<Float8E5M2Type, Float8E4M3FNType>(inputEType) && !accType.isF16())
357  return op.emitOpError("accumulator type for f8 tensor is not f16");
358 
359  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
360  return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
361 
362  if (inputEType.isBF16() && !accType.isF32())
363  return op.emitOpError("accumulator type for bf16 tensor is not f32");
364 
365  if (inputEType.isF32() && !accType.isF32())
366  return op.emitOpError("accumulator type for f32 tensor is not f32");
367 
368  auto resultEType =
369  llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
370 
371  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
372  resultEType = quantType.getStorageType();
373 
374  // check allowed input/result element types combinations
375  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
376  (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
377  (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
378  (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
379  (inputEType.isF16() && resultEType.isF16()) ||
380  (inputEType.isBF16() && resultEType.isBF16()) ||
381  (inputEType.isF32() && resultEType.isF32()))
382  return success();
383 
384  return op.emitOpError("input/output element types are incompatible.");
385 }
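// Summary of the accumulator/result pairings enforced above, derived from the
// checks in verifyConvOpModes and shown here only for illustration:
//   input i8   -> acc i32, result i32
//   input i16  -> acc i48, result i48
//   input f8   -> acc f16, result f16
//   input f16  -> acc f16 or f32, result f16
//   input bf16 -> acc f32, result bf16
//   input f32  -> acc f32, result f32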
386 
387 // verify that inType and outType have same element types
388 template <typename T>
389 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
390  auto inputType = llvm::dyn_cast<TensorType>(inType);
391  auto outputType = llvm::dyn_cast<TensorType>(outType);
392  if (!inputType) {
393  op.emitOpError("expect shaped tensor for input, got ") << inType;
394  return failure();
395  }
396  if (!outputType) {
397  op.emitOpError("expect shaped tensor for output, got ") << outType;
398  return failure();
399  }
400  auto inputElementType = inputType.getElementType();
401  auto outputElementType = outputType.getElementType();
402  auto inputQuantType =
403  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
404  auto outputQuantType =
405  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
406  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
407  (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
408  inputElementType != outputElementType) {
409  // only check if both element types are int/index/float/UniformQuantized
410  // eg, not sure how to check quant::QuantizedType
411  // this happens in test_conv2d_q_grouped_convolution in
412  // tfl-to-tosa-pipeline.mlir
413  op.emitOpError("expect input and output to have same element type, got ")
414  << inputElementType << " and " << outputElementType;
415  return failure();
416  }
417  return success();
418 }
419 
420 LogicalResult tosa::ArgMaxOp::verify() {
421  // Ensure output is of 32-bit integer
422  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
423  if (!resultETy.isIntOrIndex())
424  return emitOpError("result tensor is not of integer type");
425 
426  // Ensure axis is within the tensor rank
427  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
428  const int64_t axis = getAxisAttr().getInt();
429  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
430  return emitOpError("specified axis is outside the rank of the tensor");
431 
432  return success();
433 }
434 
435 LogicalResult tosa::AvgPool2dOp::verify() {
436  auto inputType = llvm::cast<ShapedType>(getInput().getType());
437 
438  auto inputETy = inputType.getElementType();
439  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
440 
441  if (auto quantType =
442  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
443  inputETy = quantType.getStorageType();
444 
445  if (auto quantType =
446  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
447  resultETy = quantType.getStorageType();
448 
449  auto accType = getAccType();
450  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
451  return emitOpError("accumulator type for integer tensor is not i32");
452 
453  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
454  return emitOpError("accumulator type for f16 tensor is not f16/f32");
455 
456  if (inputETy.isBF16() && !accType.isF32())
457  return emitOpError("accumulator type for bf16 tensor is not f32");
458 
459  if (inputETy.isF32() && !accType.isF32())
460  return emitOpError("accumulator type for f32 tensor is not f32");
461 
462  if ((inputETy.isF32() && resultETy.isF32()) ||
463  (inputETy.isF16() && resultETy.isF16()) ||
464  (inputETy.isBF16() && resultETy.isBF16()) ||
465  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
466  (inputETy.isInteger(16) && resultETy.isInteger(16)))
467  return success();
468 
469  return emitOpError("input/output element types are incompatible.");
470 }
471 
472 LogicalResult tosa::ClampOp::verify() {
473  mlir::Type inputETy =
474  llvm::cast<ShapedType>(getInput().getType()).getElementType();
475  if (auto quantType =
476  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
477  inputETy = quantType.getStorageType();
478  }
479  mlir::Type outputETy =
480  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
481  if (auto quantType =
482  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
483  outputETy = quantType.getStorageType();
484  }
485  if (inputETy != outputETy)
486  return emitOpError("input/output element types are incompatible.");
487 
488  auto maxValAttr = getMaxValAttr();
489  auto minValAttr = getMinValAttr();
490 
491  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
492 
493  if (inputETy.isInteger(dataTypeBitWidth)) {
494  // if input datatype is integer, check that the min_val/max_val attributes
495  // are integer attributes, and that their type is the same as the input's
496  // datatype
497  auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
498  auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
499  if (!intMaxValAttr || !intMinValAttr ||
500  (intMaxValAttr.getType() != intMinValAttr.getType()) ||
501  (intMaxValAttr.getType() != inputETy))
502  return emitOpError("min/max attributes types are incompatible with "
503  "input/output element types.");
504  } else {
505  // otherwise, input datatype is float, check that the min_val/max_val
506  // attributes share the same type and that their type is the same as the
507  // input's datatype
508  auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
509  auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
510  if (!floatMaxValAttr || !floatMinValAttr ||
511  (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
512  (floatMaxValAttr.getType() != inputETy))
513  return emitOpError("min/max attributes types are incompatible with "
514  "input/output element types.");
515  }
516 
517  return success();
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // TOSA Operator Quantization Builders.
522 //===----------------------------------------------------------------------===//
523 
524 /// This builder is called on all convolution operators except TransposeConv,
525 /// which has specialized output shape semantics. The builder also defines the
526 /// bitwidth of the output given the bit width of the input & weight content.
527 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
528  Type outputType, Value input, Value weight,
529  Value bias, DenseI64ArrayAttr pad,
530  DenseI64ArrayAttr stride,
531  DenseI64ArrayAttr dilation,
532  TypeAttr accType) {
533  auto zps = createZPsAsConst(builder, input, weight);
534  result.addOperands({input, weight, bias, zps.first, zps.second});
535  result.addAttribute("pad", pad);
536  result.addAttribute("stride", stride);
537  result.addAttribute("dilation", dilation);
538  result.addAttribute("acc_type", accType);
539  Type finalOutputType = outputType;
540  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
541  if (quantAttr) {
542  finalOutputType =
543  buildConvOpResultTypeInfo(builder, outputType, input, weight);
544  }
545  result.addTypes(finalOutputType);
546 }
547 
548 /// Handles tosa.transpose_conv2d which has outpad and output shape
549 /// attributes.
550 static void buildTransConvOpWithQuantInfo(
551  OpBuilder &builder, OperationState &result, Type outputType, Value input,
552  Value weight, Value bias, DenseI64ArrayAttr outpad,
553  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
554  auto zps = createZPsAsConst(builder, input, weight);
555  result.addOperands({input, weight, bias, zps.first, zps.second});
556  result.addAttribute("out_pad", outpad);
557  result.addAttribute("stride", stride);
558  result.addAttribute("out_shape", outputShape);
559  result.addAttribute("acc_type", accType);
560  Type finalOutputType = outputType;
561  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
562  if (quantAttr) {
563  finalOutputType =
564  buildConvOpResultTypeInfo(builder, outputType, input, weight);
565  }
566  result.addTypes(finalOutputType);
567 }
568 
569 /// The tosa.matmul op is also intended to be generated where a fully_connected
570 /// op must be constructed with a non-constant weight. In this case, the
571 /// fully_connected op must be expressed using matmul.
572 /// TODO: Add link to the legalization document explaining this.
573 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
574  OperationState &result, Type outputType,
575  Value a, Value b) {
576  result.addOperands({a, b});
577  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
578 
579  if (quantAttr) {
580  result.addAttribute("a_zp", builder.getI32IntegerAttr(
581  static_cast<int32_t>(quantAttr.getAZp())));
582  result.addAttribute("b_zp", builder.getI32IntegerAttr(
583  static_cast<int32_t>(quantAttr.getBZp())));
584 
585  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
586  assert(inputType && "Input must be a shaped tensor type!");
587 
588  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
589  inputType.getElementType());
590  assert(inputQType && "Tensor must have quantized datatype!");
591 
592  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
593 
594  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
595  assert(outputShapedType && "Output must be a shaped type");
596 
597  IntegerType accElementType;
598  if (inputBits == 16)
599  accElementType = builder.getIntegerType(48);
600  else
601  accElementType = builder.getI32Type();
602  auto accType = outputShapedType.clone(accElementType);
603  result.addTypes(accType);
604  } else {
605  result.addTypes(outputType);
606  }
607 }
608 
609 /// Both the tosa.avg_pool2d and unary ops use the same
610 /// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as it
611 /// has additional parameters not part of the unary ops.
612 static void
613 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
614  Type outputType, Value input,
615  DenseArrayAttr kernel, DenseArrayAttr stride,
616  DenseArrayAttr pad, TypeAttr accType) {
617  result.addOperands(input);
618  result.addAttribute("kernel", kernel);
619  result.addAttribute("stride", stride);
620  result.addAttribute("pad", pad);
621  result.addAttribute("acc_type", accType);
622  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
623  if (quantAttr) {
624  result.addAttribute("input_zp",
625  builder.getI32IntegerAttr(
626  static_cast<int32_t>(quantAttr.getInputZp())));
627  result.addAttribute("output_zp",
628  builder.getI32IntegerAttr(
629  static_cast<int32_t>(quantAttr.getOutputZp())));
630  }
631  result.types.push_back(outputType);
632 }
633 
634 /// This builder is called on single-parameter unary operators that have a scale
635 /// relationship between their input and output, expressed by the
636 /// UnaryOpQuantizationAttr.
637 static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
638  OperationState &result, Type outputType,
639  Value input) {
640  result.addOperands(input);
641  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
642  if (quantAttr) {
643  // note: negateOp has attributes input1_zp and output_zp
644  result.addAttribute("input1_zp",
645  builder.getI32IntegerAttr(
646  static_cast<int32_t>(quantAttr.getInputZp())));
647  result.addAttribute("output_zp",
648  builder.getI32IntegerAttr(
649  static_cast<int32_t>(quantAttr.getOutputZp())));
650  }
651  result.types.push_back(outputType);
652 }
653 
654 /// This builder is called on the TOSA pad operator that needs to create its own
655 /// OptionalAttr quantization_attr parameter to scale the padding values
656 /// correctly. No pad_const is interpreted as zero-padding.
657 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
658  Type outputType, Value input,
659  Value paddings) {
660  result.addOperands({input, paddings});
661  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
662  if (quantAttr) {
663  result.addAttribute("input_zp",
664  builder.getI32IntegerAttr(
665  static_cast<int32_t>(quantAttr.getInputZp())));
666  }
667  result.types.push_back(outputType);
668 }
669 
670 /// This builder is called on the TOSA pad operator when an explicit pad_const
671 /// value is passed in. It also optionally constructs quantization_attr.
672 static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
673  OperationState &result,
674  Type outputType, Value input,
675  Value paddings,
676  Value padConst) {
677  result.addOperands({input, paddings, padConst});
678  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
679  if (quantAttr) {
680  result.addAttribute("input_zp",
681  builder.getI32IntegerAttr(
682  static_cast<int32_t>(quantAttr.getInputZp())));
683  }
684  result.types.push_back(outputType);
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // TOSA Operator Return Type Inference.
689 //===----------------------------------------------------------------------===//
690 
691 static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
692  SmallVector<int64_t> &outShape) {
693  int64_t outRank = 0;
694  for (int i = 0, e = operands.size(); i != e; ++i) {
695  auto shape = operands.getShape(i);
696  if (!shape.hasRank()) {
697  // TODO(jennik): Update function to have better case handling for
698  // invalid operands and for ranked tensors.
699  return failure();
700  }
701  outRank = std::max<int64_t>(outRank, shape.getRank());
702  }
703 
704  outShape.resize(outRank, 1);
705 
706  for (int i = 0, e = operands.size(); i != e; ++i) {
707  auto shape = operands.getShape(i);
708  auto rankDiff = outShape.size() - shape.getRank();
709 
710  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
711  auto dim1 = outShape[i + rankDiff];
712  auto dim2 = shape.getDimSize(i);
713  auto resolvedDim = dim1;
714 
715  if (dim1 == 1) {
716  resolvedDim = dim2;
717  } else if (dim2 == 1) {
718  resolvedDim = dim1;
719  } else if (dim1 != dim2) {
720  return failure();
721  }
722  outShape[i + rankDiff] = resolvedDim;
723  }
724  }
725 
726  return success();
727 }
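// Worked example of the broadcast rule implemented above (illustration only):
// shapes [2, 1, 3] and [1, 4, 3] resolve to [2, 4, 3]; shapes [2, 3] and
// [2, 5] fail because neither conflicting dimension is 1.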
728 
729 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
730  MLIRContext *context, ::std::optional<Location> location,
731  ArgMaxOp::Adaptor adaptor,
732  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
733  ShapeAdaptor inputShape(adaptor.getInput().getType());
734  IntegerAttr axis = adaptor.getProperties().axis;
735  int32_t axisVal = axis.getValue().getSExtValue();
736 
737  if (!inputShape.hasRank()) {
738  inferredReturnShapes.push_back(ShapedTypeComponents());
739  return success();
740  }
741 
742  SmallVector<int64_t> outShape;
743  outShape.reserve(inputShape.getRank() - 1);
744  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
745  if (i == axisVal)
746  continue;
747  outShape.push_back(inputShape.getDimSize(i));
748  }
749 
750  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
751  return success();
752 }
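// For illustration: an argmax over a tensor<2x3x4xf32> with axis = 1 drops the
// axis dimension, giving an inferred output shape of 2x4.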
753 
754 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
755  MLIRContext *context, ::std::optional<Location> location,
756  RFFT2dOp::Adaptor adaptor,
757  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
758  ShapeAdaptor inputShape(adaptor.getInput().getType());
759 
760  if (!inputShape.hasRank())
761  return failure();
762 
763  llvm::SmallVector<int64_t> outputShape;
764  outputShape.resize(3, ShapedType::kDynamic);
765  outputShape[0] = inputShape.getDimSize(0);
766  outputShape[1] = inputShape.getDimSize(1);
767  int64_t inWidth = inputShape.getDimSize(2);
768 
769  // Note that we can support this calculation symbolically
770  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
771  if (inWidth != ShapedType::kDynamic)
772  outputShape[2] = inWidth / 2 + 1;
773 
774  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
775  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
776 
777  return success();
778 }
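// For illustration: an rfft2d input of shape [N, H, W] with static W infers two
// outputs (real and imaginary parts) of shape [N, H, W / 2 + 1].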
779 
780 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
781  MLIRContext *context, ::std::optional<Location> location,
782  FFT2dOp::Adaptor adaptor,
783  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
784  inferredReturnShapes.push_back(
785  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
786  inferredReturnShapes.push_back(
787  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
788  return success();
789 }
790 
791 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
792  MLIRContext *context, ::std::optional<Location> location,
793  ConcatOp::Adaptor adaptor,
794  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
795  // Infer all dimension sizes by reducing based on inputs.
796  const Properties &prop = adaptor.getProperties();
797  int32_t axis = prop.axis.getValue().getSExtValue();
798  llvm::SmallVector<int64_t> outputShape;
799  bool hasRankedInput = false;
800  for (auto operand : adaptor.getOperands()) {
801  ShapeAdaptor operandShape(operand.getType());
802  if (!operandShape.hasRank())
803  continue;
804 
805  // Copy the Operand's rank.
806  if (!hasRankedInput)
807  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
808 
809  // Copy shapes until the dim is non-dynamic.
810  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
811  if (i == axis || operandShape.isDynamicDim(i))
812  continue;
813  if (outputShape[i] == ShapedType::kDynamic)
814  outputShape[i] = operandShape.getDimSize(i);
815  if (outputShape[i] != operandShape.getDimSize(i))
816  return emitOptionalError(location,
817  "Cannot concat tensors with different sizes"
818  " on the non-axis dimension ",
819  i);
820  }
821 
822  hasRankedInput = true;
823  }
824  Type inputType =
825  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
826  if (!hasRankedInput) {
827  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
828  return success();
829  }
830 
831  // Determine the dimension size along the concatenation axis.
832  int64_t concatDimSize = 0;
833  for (auto operand : adaptor.getOperands()) {
834  ShapeAdaptor operandShape(operand.getType());
835 
836  // We need to know the length of the concatenation axis of all inputs to
837  // determine the dimension size of the output shape.
838  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
839  concatDimSize = ShapedType::kDynamic;
840  break;
841  }
842 
843  concatDimSize += operandShape.getDimSize(axis);
844  }
845 
846  outputShape[axis] = concatDimSize;
847 
848  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
849  return success();
850 }
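// For illustration: concatenating tensor<2x3xf32> and tensor<2x?xf32> along
// axis = 1 keeps the non-axis dimension (2) and makes the axis dimension
// dynamic, since one input's extent along the axis is unknown.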
851 
852 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
853  MLIRContext *context, ::std::optional<Location> location,
854  ValueShapeRange operands, DictionaryAttr attributes,
855  OpaqueProperties properties, RegionRange regions,
856  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
857  auto elementType = IntegerType::get(context, /*width=*/1);
858 
859  llvm::SmallVector<int64_t> outShape;
860  if (resolveBroadcastShape(operands, outShape).failed()) {
861  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
862  return success();
863  }
864 
865  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
866  return success();
867 }
868 
869 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
870  if (l.size() != r.size() || l.size() != 1)
871  return false;
872  return succeeded(verifyCompatibleShape(l[0], r[0]));
873 }
874 
875 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
876  MLIRContext *context, ::std::optional<Location> location,
877  MatMulOp::Adaptor adaptor,
878  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
879  ShapeAdaptor lhsShape(adaptor.getA().getType());
880  ShapeAdaptor rhsShape(adaptor.getB().getType());
881 
882  // All shapes are dynamic.
883  SmallVector<int64_t> outShape;
884  outShape.resize(3, ShapedType::kDynamic);
885 
886  if (lhsShape.hasRank()) {
887  outShape[0] = lhsShape.getDimSize(0);
888  outShape[1] = lhsShape.getDimSize(1);
889  }
890 
891  if (rhsShape.hasRank()) {
892  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
893  : outShape[0];
894  outShape[2] = rhsShape.getDimSize(2);
895  }
896 
897  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
898  return success();
899 }
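// For illustration: multiplying tensor<1x4x8> (A) by tensor<1x8x16> (B) infers
// an output shape of 1x4x16; a dynamic batch dimension on A is filled in from
// B when B's batch dimension is static.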
900 
901 LogicalResult tosa::PadOp::inferReturnTypeComponents(
902  MLIRContext *context, ::std::optional<Location> location,
903  PadOp::Adaptor adaptor,
904  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
905  ShapeAdaptor inputShape(adaptor.getInput1().getType());
906  auto paddingRank =
907  cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
908  SmallVector<int64_t> outputShape;
909 
910  // If the input rank is unknown, we can infer the output rank using the
911  // padding shape's rank divided by 2.
912  if (!inputShape.hasRank()) {
913  outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
914  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
915  return success();
916  }
917 
918  SmallVector<int64_t> paddingValues;
919  // If the paddings value is not a constant, all dimensions must be dynamic.
920  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
921  paddingValues)) {
922  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
923  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
924  return success();
925  }
926 
927  outputShape.reserve(inputShape.getRank());
928  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
929  if (inputShape.isDynamicDim(i)) {
930  outputShape.push_back(ShapedType::kDynamic);
931  continue;
932  }
933  auto padFront = paddingValues[i * 2];
934  auto padBack = paddingValues[i * 2 + 1];
935  if (padFront < 0 || padBack < 0) {
936  // if either padding for dim i is -1, output dim is unknown
937  outputShape.push_back(ShapedType::kDynamic);
938  continue;
939  }
940 
941  outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
942  }
943 
944  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
945  return success();
946 }
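// For illustration: padding a tensor<2x3xf32> with the constant padding shape
// [1, 1, 2, 2] (front/back pairs per dimension) infers an output shape of 4x7;
// a negative padding entry marks that output dimension as dynamic.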
947 
948 LogicalResult tosa::PadOp::verify() {
949  RankedTensorType inputType = getInput1().getType();
950  RankedTensorType outputType = getOutput().getType();
951  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
952 
953  if (inputType.getRank() != outputType.getRank())
954  return emitOpError() << "expect same input and output tensor rank.";
955 
956  if (paddingRank != inputType.getRank() * 2)
957  return emitOpError() << "expected padding tensor dim 0 to have size "
958  << inputType.getRank() * 2
959  << " (2*rank(shape1)) but got size " << paddingRank;
960 
961  return success();
962 }
963 
964 static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
965  return to_vector(llvm::map_range(shape, [](int64_t dim) {
966  return dim == -1 ? ShapedType::kDynamic : dim;
967  }));
968 }
969 
970 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
971  MLIRContext *context, ::std::optional<Location> location,
972  SliceOp::Adaptor adaptor,
973  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
974 
975  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
976  SmallVector<int64_t> start;
977  SmallVector<int64_t> size;
978 
979  if (!tosa::getConstShapeValue(adaptor.getStart().getDefiningOp(), start) ||
980  !tosa::getConstShapeValue(adaptor.getSize().getDefiningOp(), size)) {
981  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
982  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
983  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
984  return success();
985  }
986 
987  // if size[i] is -1, all remaining elements in dimension i are included
988  // in the slice, similar to TF.
989  ShapeAdaptor inputShape(adaptor.getInput1().getType());
990  // initialize outputShape to all unknown
991  SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
992  if (inputShape.hasRank()) {
993  for (size_t i = 0; i < size.size(); i++) {
994  if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
995  (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
996  start[i] < inputShape.getDimSize(i))) {
997  // size[i] is not 0 and not < -1, and start[i] is in valid range
998  if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
999  // input shape has unknown dim[i] - only valid if size[i] > 0
1000  if (size[i] > 0) {
1001  outputShape[i] = size[i];
1002  }
1003  } else {
1004  // input shape has known dim[i]
1005  if (size[i] == -1) {
1006  outputShape[i] = inputShape.getDimSize(i) - start[i];
1007  } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
1008  // start[i] + size[i] is within bound of input shape's dim[i]
1009  outputShape[i] = size[i];
1010  }
1011  }
1012  }
1013  }
1014  } else {
1015  outputShape = convertToMlirShape(size);
1016  }
1017  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1018  return success();
1019 }
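// For illustration: slicing a tensor<4x6xf32> with start = [1, 2] and
// size = [2, -1] infers an output shape of 2x4, since a size of -1 takes all
// remaining elements of that dimension (6 - 2 = 4).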
1020 
1021 LogicalResult tosa::SliceOp::verify() {
1022  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
1023  if (!inputType)
1024  return success();
1025 
1026  auto startShapeRank =
1027  llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
1028  if (inputType.getRank() != startShapeRank)
1029  return emitOpError(
1030  "length of start attribute is not equal rank of input shape");
1031 
1032  auto sizeShapeRank =
1033  llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
1034  if (inputType.getRank() != sizeShapeRank)
1035  return emitOpError(
1036  "length of size attribute is not equal rank of input shape");
1037 
1038  return success();
1039 }
1040 
1041 LogicalResult tosa::MulOp::inferReturnTypeComponents(
1042  MLIRContext *context, ::std::optional<Location> location,
1043  ValueShapeRange operands, DictionaryAttr attributes,
1044  OpaqueProperties properties, RegionRange regions,
1045  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1046  // mul op's output shape only depends on input1 and input2, not on shift
1047  ValueShapeRange twoInputs = operands.drop_back();
1048  llvm::SmallVector<int64_t> outShape;
1049  if (resolveBroadcastShape(twoInputs, outShape).failed()) {
1050  inferredReturnShapes.push_back(ShapedTypeComponents());
1051  } else {
1052  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1053  }
1054  return success();
1055 }
1056 
1057 LogicalResult tosa::MulOp::verify() {
1058  auto resElemType = getElementTypeOrSelf(getOutput());
1059 
1060  // Verify if the element type among operands and result match tosa
1061  // specification.
1062  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
1063  IntegerType lhsIntType =
1064  cast<IntegerType>(getElementTypeOrSelf(getInput1()));
1065  IntegerType rhsIntType =
1066  cast<IntegerType>(getElementTypeOrSelf(getInput2()));
1067  if (lhsIntType != rhsIntType)
1068  return emitOpError("requires the same element type for all operands");
1069 
1070  // Though the spec requires the element type of result to be i32, a more
1071  // relaxed way is provided at dialect level for easier cooperating with
1072  // other dialects.
1073  if (lhsIntType.getWidth() > resIntType.getWidth())
1074  return emitOpError("invalid data type size for operands or result");
1075 
1076  } else {
1077  // For other supported types, the spec requires the same element
1078  // type for all operands (excludes `shift` operand) and results.
1079  for (int i = 0; i < 2; ++i) {
1080  if (getElementTypeOrSelf(getOperand(i)) != resElemType)
1081  return emitOpError(
1082  "requires the same element type for all operands and results");
1083  }
1084 
1085  // verify shift has value 0 for non-integer types
1086  ElementsAttr shift_elem;
1087  if (matchPattern(getShift(), m_Constant(&shift_elem))) {
1088  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
1089  if (shift != 0) {
1090  return emitOpError() << "require shift to be 0 for float type";
1091  }
1092  }
1093  }
1094 
1095  // Verify the op has the same rank for all main operands (excludes extra
1096  // operands such as the shift of mul op, so this is the only difference from
1097  // the built-in `SameOperandsAndResultRank` trait) and result types, if known.
1098 
1099  // delegate function that returns true if type is a shaped type with known
1100  // rank
1101  auto hasRank = [](const Type type) {
1102  if (auto shaped_type = dyn_cast<ShapedType>(type))
1103  return shaped_type.hasRank();
1104 
1105  return false;
1106  };
1107 
1108  auto rankedOperandTypes =
1109  llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1110 
1111  auto rankedResultTypes =
1112  llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1113 
1114  // If all operands and results are unranked, then no further verification.
1115  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1116  return success();
1117 
1118  // delegate function that returns rank of shaped type with known rank
1119  auto getRank = [](const Type type) {
1120  return cast<ShapedType>(type).getRank();
1121  };
1122 
1123  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1124  : getRank(*rankedResultTypes.begin());
1125 
1126  for (size_t i = 0; i < 2; ++i) {
1127  if (rank != getRank(rankedOperandTypes[i])) {
1128  return emitOpError("operands don't have matching ranks");
1129  }
1130  }
1131 
1132  for (const auto type : rankedResultTypes) {
1133  if (rank != getRank(type)) {
1134  return emitOpError("result type has different rank than operands");
1135  }
1136  }
1137 
1138  // check for broadcast compatible shapes in first two operands (ignoring
1139  // shift)
1140 
1141  // delegate function that returns shape of shaped type
1142  auto getShape = [](const Type type) {
1143  return mlir::cast<ShapedType>(type).getShape();
1144  };
1145  SmallVector<int64_t> resultShape;
1146  if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1147  getShape(rankedOperandTypes[1]),
1148  resultShape)) {
1149  return emitOpError("operands don't have broadcast-compatible shapes");
1150  }
1151 
1152  return success();
1153 }
1154 
1155 LogicalResult tosa::TableOp::inferReturnTypeComponents(
1156  MLIRContext *context, ::std::optional<Location> location,
1157  TableOp::Adaptor adaptor,
1158  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1159  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1160 
1161  if (!inputShape.hasRank()) {
1162  inferredReturnShapes.push_back(ShapedTypeComponents());
1163  return success();
1164  }
1165 
1166  inferredReturnShapes.resize(1);
1167  inputShape.getDims(inferredReturnShapes[0]);
1168  return success();
1169 }
1170 
1171 LogicalResult tosa::TableOp::verify() {
1172  TensorType inputType = getInput1().getType();
1173  TensorType outputType = getOutput().getType();
1174 
1175  if (inputType.hasRank() && outputType.hasRank() &&
1176  inputType.getRank() != outputType.getRank())
1177  return emitOpError()
1178  << "expected input tensor rank to equal result tensor rank";
1179 
1180  auto inputDims = inputType.getShape();
1181  auto outputDims = outputType.getShape();
1182  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
1183  int64_t dim = it.index();
1184  auto [inputDim, outputDim] = it.value();
1185  if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
1186  return emitOpError() << "dim(result, " << dim << ") = " << outputDim
1187  << " doesn't match dim(input, " << dim
1188  << ") = " << inputDim;
1189  }
1190  }
1191  return success();
1192 }
1193 
1194 LogicalResult
1195 tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
1196  // Multiples must be constants.
1197  DenseIntElementsAttr multiplesAttr;
1198  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
1199  return failure();
1200  multiples = llvm::to_vector(
1201  llvm::map_range(multiplesAttr.getValues<APInt>(),
1202  [](const APInt &val) { return val.getSExtValue(); }));
1203  return success();
1204 }
1205 
1206 LogicalResult tosa::TileOp::inferReturnTypeComponents(
1207  MLIRContext *context, ::std::optional<Location> location,
1208  TileOp::Adaptor adaptor,
1209  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1210  DenseIntElementsAttr multiplesAttr;
1211  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
1212  return failure();
1213 
1214  SmallVector<int64_t> multiples = llvm::to_vector(
1215  llvm::map_range(multiplesAttr.getValues<APInt>(),
1216  [](const APInt &val) { return val.getSExtValue(); }));
1217 
1218  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1219  SmallVector<int64_t> outputShape;
1220  if (!inputShape.hasRank()) {
1221  outputShape.resize(multiples.size(), ShapedType::kDynamic);
1222  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1223  return success();
1224  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
1225  return failure();
1226 
1227  // Any non dynamic dimension can be multiplied to a known size.
1228  outputShape.reserve(multiples.size());
1229  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1230  int64_t dim = inputShape.getDimSize(i);
1231  if (dim != ShapedType::kDynamic)
1232  dim *= multiples[i];
1233  outputShape.push_back(dim);
1234  }
1235 
1236  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1237  return success();
1238 }
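// For illustration: tiling a tensor<2x?x3xf32> with multiples [3, 2, 4] infers
// an output shape of 6x?x12; dynamic input dimensions stay dynamic.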
1239 
1240 LogicalResult tosa::TileOp::verify() {
1241  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
1242  ShapedType outputType = llvm::cast<ShapedType>(getType());
1243 
1244  shapeType multiplesType =
1245  llvm::cast<tosa::shapeType>(getMultiples().getType());
1246 
1247  auto multiplesRank = multiplesType.getRank();
1248 
1249  if (inputType.hasRank()) {
1250  if (inputType.getRank() != multiplesRank)
1251  return emitOpError("expect 'multiples' to have rank ")
1252  << inputType.getRank() << " but got " << multiplesRank << ".";
1253  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
1254  return emitOpError("expect same input and output tensor rank.");
1255  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
1256  return emitOpError("expect 'multiples' array to have length ")
1257  << outputType.getRank() << " but got " << multiplesRank << ".";
1258 
1259  SmallVector<int64_t> multiples;
1260  if (getConstantMultiples(multiples).succeeded() &&
1261  llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1262  return emitOpError(
1263  "expect element of 'multiples' to be positive integer or -1.");
1264 
1265  return success();
1266 }
1267 
1268 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1269  if (l.size() != r.size() || l.size() != 1)
1270  return false;
1271  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
1272 }
1273 
1274 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1275  MLIRContext *context, ::std::optional<Location> location,
1276  ReshapeOp::Adaptor adaptor,
1277  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1278  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1279  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
1280  llvm::SmallVector<int64_t> newShapeValue;
1281  if (!tosa::getConstShapeValue(adaptor.getShape().getDefiningOp(),
1282  newShapeValue)) {
1283  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
1284  SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
1285  inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
1286  return success();
1287  } else {
1288  newShapeValue = convertToMlirShape(newShapeValue);
1289  }
1290 
1291  // We cannot infer from the total number of elements so we must take the
1292  // shape attribute as exact.
1293  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
1294  inferredReturnShapes.push_back(
1295  ShapedTypeComponents(newShapeValue, inputType));
1296  return success();
1297  }
1298 
1299  // Determine the number of elements covered by the slice of all static
1300  // dimensions. This allows us to infer the length of the remaining dynamic
1301  // dimension.
1302  int64_t numElements = inputShape.getNumElements();
1303  int64_t staticMul = 1;
1304  for (auto val : newShapeValue) {
1305  if (!ShapedType::isDynamic(val)) {
1306  staticMul *= val;
1307  }
1308  }
1309 
1310  // Determine the length of the dynamic dimension.
1311  for (auto &val : newShapeValue) {
1312  if (ShapedType::isDynamic(val))
1313  val = numElements / staticMul;
1314  }
1315 
1316  inferredReturnShapes.push_back(
1317  ShapedTypeComponents(newShapeValue, inputType));
1318  return success();
1319 }
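// For illustration: reshaping a static tensor<2x3x4xf32> (24 elements) with the
// constant shape [4, -1, 2] resolves the dynamic dimension to 24 / (4 * 2) = 3,
// giving an inferred shape of 4x3x2.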
1320 
1321 llvm::LogicalResult tosa::ReshapeOp::verify() {
1322  TensorType inputType = getInput1().getType();
1323  RankedTensorType outputType = getType();
1324 
1325  SmallVector<int64_t> shapeValues;
1326  if (!tosa::getConstShapeValue(getShape().getDefiningOp(), shapeValues)) {
1327  // skip following checks if shape is not constant
1328  return mlir::success();
1329  }
1330 
1331  if ((int64_t)shapeValues.size() != outputType.getRank())
1332  return emitOpError() << "new shape does not match result rank";
1333 
1334  for (auto [newShapeDim, outputShapeDim] :
1335  zip(shapeValues, outputType.getShape())) {
1336  if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
1337  outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
1338  return emitOpError() << "new shape is inconsistent with result shape";
1339 
1340  if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1341  return emitOpError() << "new shape has invalid tensor dimension size "
1342  << newShapeDim;
1343  }
1344 
1345  if (inputType.hasStaticShape()) {
1346  int64_t inputElementsNum = inputType.getNumElements();
1347  if (outputType.hasStaticShape()) {
1348  int64_t outputElementsNum = outputType.getNumElements();
1349  if (inputElementsNum != outputElementsNum) {
1350  return emitOpError() << "cannot reshape " << inputElementsNum
1351  << " elements into " << outputElementsNum;
1352  }
1353  }
1354 
1355  int64_t newShapeElementsNum = std::accumulate(
1356  shapeValues.begin(), shapeValues.end(), 1LL,
1357  [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
1358  bool isStaticNewShape =
1359  llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
1360  if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
1361  (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
1362  return emitOpError() << "cannot reshape " << inputElementsNum
1363  << " elements into " << newShapeElementsNum;
1364  }
1365  }
1366 
1367  int missingDims = llvm::count(shapeValues, -1);
1368  if (missingDims > 1)
1369  return emitOpError() << "expected at most one target dimension to be -1";
1370 
1371  return mlir::success();
1372 }
1373 
1374 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int32_t> &perms) {
1375  // Perms must be constants.
1376  DenseIntElementsAttr permsAttr;
1377  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
1378  return failure();
1379 
1380  perms.clear();
1381  for (auto v : permsAttr.getValues<APInt>())
1382  perms.push_back(v.getSExtValue());
1383 
1384  return success();
1385 }
1386 
1387 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1388  MLIRContext *context, ::std::optional<Location> location,
1389  TransposeOp::Adaptor adaptor,
1390  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1391  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1392  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1393 
1394  // We cannot infer anything from a rank-0 "permutation" tensor.
1395  if (permsShape.hasRank() && permsShape.getRank() == 0)
1396  return failure();
1397 
1398  // If input rank and permutation length is unknown, the output rank is
1399  // If the input rank and permutation length are unknown, the output rank is
1400  // unknown.
1401  permsShape.isDynamicDim(0)) {
1402  inferredReturnShapes.push_back(ShapedTypeComponents());
1403  return success();
1404  }
1405 
1406  // This would imply the number of permutations does not match the rank of
1407  // the input which is illegal.
1408  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1409  return failure();
1410  }
1411 
1412  SmallVector<int64_t> outputShape;
1413  // Rank-0 means no permutations matter.
1414  if (inputShape.getRank() == 0) {
1415  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1416  return success();
1417  }
1418 
1419  // Check whether the input dimensions are all the same.
1420  bool allTheSame = true;
1421  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1422  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1423  allTheSame = false;
1424  break;
1425  }
1426  }
1427 
1428  // If all of the input dimensions are the same we don't care about the
1429  // permutation.
1430  if (allTheSame) {
1431  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1432  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1433  return success();
1434  }
1435 
1436  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1437  // If the permutations are a constant we can directly determine the output
1438  // shape.
1439  DenseIntElementsAttr attr;
1440  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1441  attr.getType().getRank() == 1) {
1442  ShapeAdaptor permShape = attr;
1443  // Constant permutation must be the same length as the input rank.
1444  if (inputShape.getRank() != permShape.getRank())
1445  return emitOptionalError(location,
1446  "constant permutation must be the same length"
1447  " as the input rank");
1448 
1449  // Constant permutation values must be within the input rank.
1450  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1451  if (inputShape.getRank() <= permShape.getDimSize(i))
1452  return failure();
1453  }
1454 
1455  outputShape.reserve(inputShape.getRank());
1456  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1457  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1458  }
1459  }
1460 
1461  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1462  return success();
1463 }
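// For illustration: transposing a tensor<2x3x4xf32> with constant perms
// [2, 0, 1] infers an output shape of 4x2x3 (output dim i takes input dim
// perms[i]).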
1464 
1465 LogicalResult tosa::TransposeOp::verify() {
1466  TensorType inputType = getInput1().getType();
1467  TensorType permType = getPerms().getType();
1468  TensorType outputType = getOutput().getType();
1469 
1470  if (permType.hasRank() && permType.getRank() != 1)
1471  return emitOpError()
1472  << "expected permutation tensor to be rank 1 but got rank "
1473  << permType.getRank();
1474  if (inputType.hasRank() && permType.hasRank())
1475  if (!permType.isDynamicDim(0) &&
1476  permType.getDimSize(0) != inputType.getRank())
1477  return emitOpError() << "expected permutation tensor dim 0 to have size "
1478  << inputType.getRank()
1479  << " (input rank) but got size "
1480  << permType.getDimSize(0);
1481  if (inputType.hasRank() && outputType.hasRank() &&
1482  inputType.getRank() != outputType.getRank())
1483  return emitOpError()
1484  << "expected input tensor rank to equal result tensor rank";
1485  if (outputType.hasRank() && permType.hasRank())
1486  if (!permType.isDynamicDim(0) &&
1487  permType.getDimSize(0) != outputType.getRank())
1488  return emitOpError() << "expected permutation tensor dim 0 to have size "
1489  << outputType.getRank()
1490  << " (output rank) but got size "
1491  << permType.getDimSize(0);
1492 
1493  SmallVector<int32_t> constantPerms;
1494  if (succeeded(getConstantPerms(constantPerms))) {
1495  // Assert that the permutation tensor has a rank, which means that the
1496  // rank has been verified above.
1497  assert(permType.hasRank() &&
1498  "Unexpectedly found permutation tensor without rank");
1499  if (!llvm::all_of(constantPerms,
1500  [&constantPerms](int32_t s) {
1501  return s >= 0 &&
1502  static_cast<size_t>(s) < constantPerms.size();
1503  }) ||
1504  !isPermutationVector(llvm::to_vector(llvm::map_range(
1505  constantPerms, [](int32_t v) -> int64_t { return v; }))))
1506  return emitOpError() << "expected valid permutation tensor";
1507 
1508  // Verify that the types of the input and output tensors are properly
1509  // permuted.
1510  if (inputType.hasRank() && outputType.hasRank()) {
1511  assert(constantPerms.size() == static_cast<size_t>(inputType.getRank()) &&
1512  inputType.getRank() == outputType.getRank());
1513 
1514  for (auto i = 0; i < outputType.getRank(); i++) {
1515  if (inputType.isDynamicDim(constantPerms[i]) ||
1516  outputType.isDynamicDim(i))
1517  continue;
1518 
1519  if (inputType.getDimSize(constantPerms[i]) != outputType.getDimSize(i))
1520  return emitOpError()
1521  << "expected output tensor dim " << i << " to match "
1522  << "input dim " << constantPerms[i] << " with value of "
1523  << inputType.getDimSize(constantPerms[i]);
1524  }
1525  }
1526  }
1527  return success();
1528 }
1529 
1530 LogicalResult TransposeOp::reifyResultShapes(
1531  OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1532 
1533  SmallVector<int32_t> transposePerms;
1534  if (getConstantPerms(transposePerms).failed())
1535  return failure();
1536 
1537  Value input = getInput1();
1538  auto inputType = cast<TensorType>(input.getType());
1539 
1540  SmallVector<OpFoldResult> returnedDims(inputType.getRank());
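  // Since transposePerms is a permutation, iterating over its values visits
  // every output index exactly once; output dim `dim` gets the size of input
  // dim transposePerms[dim].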
1541  for (auto dim : transposePerms) {
1542  int32_t dimInInput = transposePerms[dim];
1543  if (inputType.isDynamicDim(dimInInput))
1544  returnedDims[dim] =
1545  builder.create<tensor::DimOp>(getLoc(), input, dimInInput)
1546  .getResult();
1547  else
1548  returnedDims[dim] =
1549  builder.getIndexAttr(inputType.getDimSize(dimInInput));
1550  }
1551 
1552  reifiedReturnShapes.emplace_back(std::move(returnedDims));
1553  return success();
1554 }
1555 
1556 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1557  MLIRContext *context, ::std::optional<Location> location,
1558  GatherOp::Adaptor adaptor,
1559  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1560  llvm::SmallVector<int64_t> outputShape;
1561  outputShape.resize(3, ShapedType::kDynamic);
1562 
1563  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1564  if (valuesShape.hasRank()) {
1565  outputShape[0] = valuesShape.getDimSize(0);
1566  outputShape[2] = valuesShape.getDimSize(2);
1567  }
1568 
1569  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1570  if (indicesShape.hasRank()) {
1571  if (outputShape[0] == ShapedType::kDynamic)
1572  outputShape[0] = indicesShape.getDimSize(0);
1573  if (outputShape[1] == ShapedType::kDynamic)
1574  outputShape[1] = indicesShape.getDimSize(1);
1575  }
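  // At this point outputShape holds [N, W, C]: N and C come from the values
  // operand and W from the indices operand; any dimension still unknown
  // remains dynamic.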
1576 
1577  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1578  return success();
1579 }
1580 
1581 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1582  MLIRContext *context, ::std::optional<Location> location,
1583  ResizeOp::Adaptor adaptor,
1584  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1585  llvm::SmallVector<int64_t, 4> outputShape;
1586  outputShape.resize(4, ShapedType::kDynamic);
1587 
1588  ShapeAdaptor inputShape(adaptor.getInput().getType());
1589  if (!inputShape.hasRank())
1590  return failure();
1591 
1592  outputShape[0] = inputShape.getDimSize(0);
1593  outputShape[3] = inputShape.getDimSize(3);
1594  int64_t inputHeight = inputShape.getDimSize(1);
1595  int64_t inputWidth = inputShape.getDimSize(2);
1596 
1597  if ((inputHeight == ShapedType::kDynamic) ||
1598  (inputWidth == ShapedType::kDynamic))
1599  return failure();
1600 
1601  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1602  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1603  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1604 
1605  // Compute the output shape based on attributes: scale, offset, and border.
1606  outputShape[1] =
1607  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1608  scaleInt[1]) +
1609  1;
1610 
1611  outputShape[2] =
1612  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1613  scaleInt[3]) +
1614  1;
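  // Example (assumed values): inputWidth = 8, scale = [4, 2, 4, 2],
  // offset = [0, 0], border = [0, 0] gives
  // outputShape[2] = ((8 - 1) * 4 - 0 + 0) / 2 + 1 = 15.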
1615 
1616  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1617  return success();
1618 }
1619 
1620 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1621  MLIRContext *context, ::std::optional<Location> location,
1622  ScatterOp::Adaptor adaptor,
1623  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1624  llvm::SmallVector<int64_t> outputShape;
1625  outputShape.resize(3, ShapedType::kDynamic);
1626 
1627  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1628  if (valuesInShape.hasRank()) {
1629  outputShape[0] = valuesInShape.getDimSize(0);
1630  outputShape[1] = valuesInShape.getDimSize(1);
1631  outputShape[2] = valuesInShape.getDimSize(2);
1632  }
1633 
1634  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1635  if (indicesShape.hasRank()) {
1636  if (outputShape[0] == ShapedType::kDynamic)
1637  outputShape[0] = indicesShape.getDimSize(0);
1638  }
1639 
1640  ShapeAdaptor inputShape(adaptor.getInput().getType());
1641  if (inputShape.hasRank()) {
1642  if (outputShape[0] == ShapedType::kDynamic)
1643  outputShape[0] = inputShape.getDimSize(0);
1644  if (outputShape[2] == ShapedType::kDynamic)
1645  outputShape[2] = inputShape.getDimSize(2);
1646  }
1647 
1648  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1649  return success();
1650 }
1651 
1652 static LogicalResult ReduceInferReturnTypes(
1653  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1654  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1655  int64_t axisVal = axis.getValue().getSExtValue();
1656  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1657  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1658  return success();
1659  }
1660 
1661  SmallVector<int64_t> outputShape;
1662  operandShape.getDims(outputShape);
1663  outputShape[axisVal] = 1;
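  // Example: reducing a ?x3x4 operand along axis 1 infers the shape ?x1x4.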
1664  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1665  return success();
1666 }
1667 
1668 #define COMPATIBLE_RETURN_TYPES(OP) \
1669  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1670  if (l.size() != r.size() || l.size() != 1) \
1671  return false; \
1672  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1673  return false; \
1674  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1675  }
1676 
1677 #define REDUCE_SHAPE_INFER(OP) \
1678  LogicalResult OP::inferReturnTypeComponents( \
1679  MLIRContext *context, ::std::optional<Location> location, \
1680  OP::Adaptor adaptor, \
1681  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1682  Type inputType = \
1683  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1684  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1685  const Properties &prop = adaptor.getProperties(); \
1686  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1687  inferredReturnShapes); \
1688  } \
1689  COMPATIBLE_RETURN_TYPES(OP)
1690 
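// Each REDUCE_SHAPE_INFER(OP) expansion below defines both
// OP::inferReturnTypeComponents (delegating to ReduceInferReturnTypes) and
// OP::isCompatibleReturnTypes for the corresponding reduce operator.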
1691 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1692 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1693 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1694 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1695 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1696 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1697 #undef REDUCE_SHAPE_INFER
1698 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1699 #undef COMPATIBLE_RETURN_TYPES
1700 
1701 template <typename T>
1702 static LogicalResult verifyReduceOp(T op) {
1703  // All TOSA reduce Ops have input, output and axis.
1704  TensorType inputType = op.getInput().getType();
1705  TensorType outputType = op.getOutput().getType();
1706  int32_t reduceAxis = op.getAxis();
1707 
1708  if (reduceAxis < 0) {
1709  op.emitOpError("reduce axis must not be negative");
1710  return failure();
1711  }
1712  if (inputType.hasRank()) {
1713  int64_t inputRank = inputType.getRank();
1714  // We allow for a special case where the input/output shape has rank 0 and
1715  // axis is also 0.
1716  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1717  op.emitOpError("expect input tensor rank (")
1718  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1719  << ")";
1720  return failure();
1721  }
1722  }
1723  if (outputType.hasRank()) {
1724  int64_t outputRank = outputType.getRank();
1725  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1726  op.emitOpError(
1727  "expect output tensor rank to be equal to input tensor rank");
1728  return failure();
1729  }
1730  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1731  op.emitOpError("expect output tensor rank (")
1732  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1733  << ")";
1734  return failure();
1735  }
1736  // We can only verify the reduced dimension size to be 1 if this is not
1737  // the special case of output rank == 0.
1738  if (outputRank != 0) {
1739  auto outputShape = outputType.getShape();
1740  if (!outputType.isDynamicDim(reduceAxis) &&
1741  outputShape[reduceAxis] != 1) {
1742  op.emitOpError("expect reduced dimension size to be 1, got ")
1743  << outputShape[reduceAxis];
1744  return failure();
1745  }
1746  }
1747  }
1748  return success();
1749 }
1750 
1751 LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
1752 LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
1753 LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
1754 LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
1755 LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
1756 LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
1757 
1758 static LogicalResult NAryInferReturnTypes(
1759  const ValueShapeRange &operands,
1760  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1761  llvm::SmallVector<int64_t> outShape;
1762  if (resolveBroadcastShape(operands, outShape).failed()) {
1763  inferredReturnShapes.push_back(ShapedTypeComponents());
1764  } else {
1765  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1766  }
1767  return success();
1768 }
1769 
1770 #define NARY_SHAPE_INFER(OP) \
1771  LogicalResult OP::inferReturnTypeComponents( \
1772  MLIRContext *context, ::std::optional<Location> location, \
1773  ValueShapeRange operands, DictionaryAttr attributes, \
1774  OpaqueProperties properties, RegionRange regions, \
1775  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1776  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1777  }
1778 
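// Each NARY_SHAPE_INFER(OP) expansion below infers OP's result shape as the
// broadcast of its operand shapes, falling back to an unranked result when
// the broadcast cannot be resolved.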
1779 NARY_SHAPE_INFER(tosa::AbsOp)
1780 NARY_SHAPE_INFER(tosa::AddOp)
1781 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1782 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1783 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1784 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1785 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1786 NARY_SHAPE_INFER(tosa::CastOp)
1787 NARY_SHAPE_INFER(tosa::CeilOp)
1788 NARY_SHAPE_INFER(tosa::ClampOp)
1789 NARY_SHAPE_INFER(tosa::ClzOp)
1790 NARY_SHAPE_INFER(tosa::CosOp)
1791 NARY_SHAPE_INFER(tosa::ExpOp)
1792 NARY_SHAPE_INFER(tosa::FloorOp)
1793 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1794 NARY_SHAPE_INFER(tosa::GreaterOp)
1795 NARY_SHAPE_INFER(tosa::IdentityOp)
1796 NARY_SHAPE_INFER(tosa::IntDivOp)
1797 NARY_SHAPE_INFER(tosa::LogOp)
1798 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1799 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1800 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1801 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1802 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1803 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1804 NARY_SHAPE_INFER(tosa::MaximumOp)
1805 NARY_SHAPE_INFER(tosa::MinimumOp)
1806 NARY_SHAPE_INFER(tosa::NegateOp)
1807 NARY_SHAPE_INFER(tosa::PowOp)
1808 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1809 NARY_SHAPE_INFER(tosa::RescaleOp)
1810 NARY_SHAPE_INFER(tosa::ReverseOp)
1811 NARY_SHAPE_INFER(tosa::RsqrtOp)
1812 NARY_SHAPE_INFER(tosa::SinOp)
1813 NARY_SHAPE_INFER(tosa::SelectOp)
1814 NARY_SHAPE_INFER(tosa::SubOp)
1815 NARY_SHAPE_INFER(tosa::TanhOp)
1816 NARY_SHAPE_INFER(tosa::ErfOp)
1817 NARY_SHAPE_INFER(tosa::SigmoidOp)
1818 #undef NARY_SHAPE_INFER
1819 
1820 static LogicalResult poolingInferReturnTypes(
1821  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1822  ArrayRef<int64_t> pad,
1823  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1824  llvm::SmallVector<int64_t> outputShape;
1825  outputShape.resize(4, ShapedType::kDynamic);
1826 
1827  // If the input type is unranked, we only know the output rank.
1828  if (!inputShape) {
1829  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1830  return success();
1831  }
1832 
1833  // Batch and number of channels are unchanged by a pooling layer.
1834  outputShape[0] = inputShape.getDimSize(0);
1835  outputShape[3] = inputShape.getDimSize(3);
1836 
1837  int64_t height = inputShape.getDimSize(1);
1838  int64_t width = inputShape.getDimSize(2);
1839 
1840  if (!ShapedType::isDynamic(height)) {
1841  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1842  outputShape[1] = padded / stride[0] + 1;
1843  }
1844 
1845  if (!ShapedType::isDynamic(width)) {
1846  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1847  outputShape[2] = padded / stride[1] + 1;
1848  }
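  // Example (assumed values): height = 8, pad = [1, 1, 1, 1], kernel = [3, 3],
  // stride = [2, 2] gives outputShape[1] = (8 + 1 + 1 - 3) / 2 + 1 = 4.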
1849 
1850  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1851  return success();
1852 }
1853 
1854 LogicalResult Conv2DOp::inferReturnTypeComponents(
1855  MLIRContext *context, ::std::optional<Location> location,
1856  Conv2DOp::Adaptor adaptor,
1857  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1858  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1859 
1860  int64_t inputWidth = ShapedType::kDynamic;
1861  int64_t inputHeight = ShapedType::kDynamic;
1862  int64_t weightWidth = ShapedType::kDynamic;
1863  int64_t weightHeight = ShapedType::kDynamic;
1864 
1865  // Input shape describes input width/height and batch.
1866 
1867  ShapeAdaptor inputShape(adaptor.getInput().getType());
1868  if (inputShape.hasRank()) {
1869  outputShape[0] = inputShape.getDimSize(0);
1870  inputHeight = inputShape.getDimSize(1);
1871  inputWidth = inputShape.getDimSize(2);
1872  }
1873 
1874  // The weight shape describes the filter width/height and the output channels.
1875  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1876  if (weightShape.hasRank()) {
1877  outputShape[3] = weightShape.getDimSize(0);
1878  weightHeight = weightShape.getDimSize(1);
1879  weightWidth = weightShape.getDimSize(2);
1880  }
1881 
1882  // Bias shape can describe the output channels.
1883  ShapeAdaptor biasShape(adaptor.getBias().getType());
1884  if (biasShape.hasRank()) {
1885  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1886  ? biasShape.getDimSize(0)
1887  : outputShape[3];
1888  }
1889 
1890  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1891  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1892  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1893 
1894  if (!ShapedType::isDynamic(inputHeight) &&
1895  !ShapedType::isDynamic(weightHeight)) {
1896  int64_t inputSize = inputHeight + padding[0] + padding[1];
1897  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1898  int64_t unstridedResult = inputSize - filterSize + 1;
1899  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1900  }
1901 
1902  if (!ShapedType::isDynamic(inputWidth) &&
1903  !ShapedType::isDynamic(weightWidth)) {
1904  int64_t inputSize = inputWidth + padding[2] + padding[3];
1905  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1906  int64_t unstridedResult = inputSize - filterSize + 1;
1907  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1908  }
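  // Example (assumed values): inputWidth = 16 with padding 1 + 1,
  // weightWidth = 3, dilation 1 and stride 2 gives inputSize = 18,
  // filterSize = 3, so outputShape[2] = (18 - 3 + 1 - 1) / 2 + 1 = 8.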
1909 
1910  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1911  return success();
1912 }
1913 
1914 LogicalResult Conv2DOp::verify() {
1915  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1916  return failure();
1917  return success();
1918 }
1919 
1920 LogicalResult Conv3DOp::inferReturnTypeComponents(
1921  MLIRContext *context, ::std::optional<Location> location,
1922  Conv3DOp::Adaptor adaptor,
1923  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1924  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1925 
1926  int64_t inputWidth = ShapedType::kDynamic;
1927  int64_t inputHeight = ShapedType::kDynamic;
1928  int64_t inputDepth = ShapedType::kDynamic;
1929 
1930  int64_t weightWidth = ShapedType::kDynamic;
1931  int64_t weightHeight = ShapedType::kDynamic;
1932  int64_t weightDepth = ShapedType::kDynamic;
1933 
1934  // Input shape describes input width/height and batch.
1935  ShapeAdaptor inputShape(adaptor.getInput().getType());
1936  if (inputShape.hasRank()) {
1937  outputShape[0] = inputShape.getDimSize(0);
1938  inputDepth = inputShape.getDimSize(1);
1939  inputHeight = inputShape.getDimSize(2);
1940  inputWidth = inputShape.getDimSize(3);
1941  }
1942 
1943  // The weight shape describes the filter depth/height/width and the output channels.
1944  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1945  if (weightShape.hasRank()) {
1946  outputShape[4] = weightShape.getDimSize(0);
1947  weightDepth = weightShape.getDimSize(1);
1948  weightHeight = weightShape.getDimSize(2);
1949  weightWidth = weightShape.getDimSize(3);
1950  }
1951 
1952  // Bias shape can describe the output channels.
1953  ShapeAdaptor biasShape(adaptor.getBias().getType());
1954  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1955  outputShape[4] = biasShape.getDimSize(0);
1956  }
1957 
1958  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1959  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1960  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1961 
1962  if (!ShapedType::isDynamic(inputDepth) &&
1963  !ShapedType::isDynamic(weightDepth)) {
1964  int32_t inputSize = inputDepth + pad[0] + pad[1];
1965  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1966  int32_t unstridedResult = inputSize - filterSize + 1;
1967  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1968  }
1969 
1970  if (!ShapedType::isDynamic(inputHeight) &&
1971  !ShapedType::isDynamic(weightHeight)) {
1972  int32_t inputSize = inputHeight + pad[2] + pad[3];
1973  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1974  int32_t unstridedResult = inputSize - filterSize + 1;
1975  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1976  }
1977 
1978  if (!ShapedType::isDynamic(inputWidth) &&
1979  !ShapedType::isDynamic(weightWidth)) {
1980  int32_t inputSize = inputWidth + pad[4] + pad[5];
1981  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1982  int32_t unstridedResult = inputSize - filterSize + 1;
1983  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1984  }
1985 
1986  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1987  return success();
1988 }
1989 
1990 LogicalResult Conv3DOp::verify() {
1991  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
1992  return failure();
1993  return success();
1994 }
1995 
1996 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1997  MLIRContext *context, ::std::optional<Location> location,
1998  AvgPool2dOp::Adaptor adaptor,
1999  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2000  ShapeAdaptor inputShape(adaptor.getInput().getType());
2001  const Properties &prop = adaptor.getProperties();
2002  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2003  inferredReturnShapes);
2004 }
2005 
2006 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2007  MLIRContext *context, ::std::optional<Location> location,
2008  MaxPool2dOp::Adaptor adaptor,
2009  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2010  ShapeAdaptor inputShape(adaptor.getInput().getType());
2011  const Properties &prop = adaptor.getProperties();
2012  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
2013  inferredReturnShapes);
2014 }
2015 
2016 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
2017  MLIRContext *context, ::std::optional<Location> location,
2018  DepthwiseConv2DOp::Adaptor adaptor,
2019  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2020  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
2021 
2022  int64_t inputWidth = ShapedType::kDynamic;
2023  int64_t inputHeight = ShapedType::kDynamic;
2024  int64_t inputChannels = ShapedType::kDynamic;
2025 
2026  int64_t weightWidth = ShapedType::kDynamic;
2027  int64_t weightHeight = ShapedType::kDynamic;
2028  int64_t depthChannels = ShapedType::kDynamic;
2029 
2030  // Input shape describes input width/height and batch.
2031  ShapeAdaptor inputShape(adaptor.getInput().getType());
2032  if (inputShape.hasRank()) {
2033  outputShape[0] = inputShape.getDimSize(0);
2034  inputHeight = inputShape.getDimSize(1);
2035  inputWidth = inputShape.getDimSize(2);
2036  inputChannels = inputShape.getDimSize(3);
2037  }
2038 
2039  // The weight shape describes the filter height/width, the input channels and the depth multiplier.
2040  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2041  if (weightShape.hasRank()) {
2042  weightHeight = weightShape.getDimSize(0);
2043  weightWidth = weightShape.getDimSize(1);
2044  inputChannels = ShapedType::isDynamic(inputChannels)
2045  ? weightShape.getDimSize(2)
2046  : inputChannels;
2047  depthChannels = weightShape.getDimSize(3);
2048  }
2049 
2050  // If both inputChannels and depthChannels are available we can determine
2051  // the output channels.
2052  if (!ShapedType::isDynamic(inputChannels) &&
2053  !ShapedType::isDynamic(depthChannels)) {
2054  outputShape[3] = inputChannels * depthChannels;
2055  }
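  // Example: 8 input channels with depthChannels = 2 yield 16 output channels.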
2056 
2057  // Bias shape can describe the output channels.
2058  ShapeAdaptor biasShape(adaptor.getBias().getType());
2059  if (biasShape.hasRank()) {
2060  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2061  ? biasShape.getDimSize(0)
2062  : outputShape[3];
2063  }
2064 
2065  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
2066  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
2067  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2068 
2069  if (!ShapedType::isDynamic(inputHeight) &&
2070  !ShapedType::isDynamic(weightHeight)) {
2071  int64_t inputSize = inputHeight + padding[0] + padding[1];
2072  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
2073  int64_t unstridedResult = inputSize - filterSize + 1;
2074  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
2075  }
2076 
2077  if (!ShapedType::isDynamic(inputWidth) &&
2078  !ShapedType::isDynamic(weightWidth)) {
2079  int64_t inputSize = inputWidth + padding[2] + padding[3];
2080  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
2081  int64_t unstridedResult = inputSize - filterSize + 1;
2082  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
2083  }
2084 
2085  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2086  return success();
2087 }
2088 
2089 LogicalResult DepthwiseConv2DOp::verify() {
2090  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2091  return failure();
2092  return success();
2093 }
2094 
2095 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
2096  MLIRContext *context, ::std::optional<Location> location,
2097  TransposeConv2DOp::Adaptor adaptor,
2098  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2099  // outputShape is mutable.
2100  llvm::SmallVector<int64_t> outputShape =
2101  convertToMlirShape(adaptor.getOutShape());
2102 
2103  int64_t inputWidth = ShapedType::kDynamic;
2104  int64_t inputHeight = ShapedType::kDynamic;
2105  int64_t weightWidth = ShapedType::kDynamic;
2106  int64_t weightHeight = ShapedType::kDynamic;
2107 
2108  // Input shape describes input width/height and batch.
2109  ShapeAdaptor inputShape(adaptor.getInput().getType());
2110  if (inputShape.hasRank()) {
2111  outputShape[0] = ShapedType::isDynamic(outputShape[0])
2112  ? inputShape.getDimSize(0)
2113  : outputShape[0];
2114  inputHeight = inputShape.getDimSize(1);
2115  inputWidth = inputShape.getDimSize(2);
2116  }
2117 
2118  // The weight shape describes the filter width/height and the output channels.
2119  ShapeAdaptor weightShape(adaptor.getWeight().getType());
2120  if (weightShape.hasRank()) {
2121  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2122  ? weightShape.getDimSize(0)
2123  : outputShape[3];
2124  weightHeight = weightShape.getDimSize(1);
2125  weightWidth = weightShape.getDimSize(2);
2126  }
2127 
2128  // Bias shape can describe the output channels.
2129  ShapeAdaptor biasShape(adaptor.getBias().getType());
2130  if (biasShape.hasRank()) {
2131  outputShape[3] = ShapedType::isDynamic(outputShape[3])
2132  ? biasShape.getDimSize(0)
2133  : outputShape[3];
2134  }
2135 
2136  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
2137  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
2138 
2139  if (!ShapedType::isDynamic(inputHeight) &&
2140  !ShapedType::isDynamic(weightHeight)) {
2141  int64_t calculateSize =
2142  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2143  outputShape[1] =
2144  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
2145  }
2146 
2147  if (!ShapedType::isDynamic(inputWidth) &&
2148  !ShapedType::isDynamic(weightWidth)) {
2149  int64_t calculateSize =
2150  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2151  outputShape[2] =
2152  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
2153  }
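  // Example (assumed values): inputWidth = 8, stride = 2, out_pad = [0, 0, 0, 0],
  // weightWidth = 3 gives (8 - 1) * 2 + 0 + 0 + 3 = 17 when outputShape[2] is
  // not already fixed by out_shape.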
2154 
2155  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2156  return success();
2157 }
2158 
2159 LogicalResult TransposeConv2DOp::verify() {
2160  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
2161  return failure();
2162  return success();
2163 }
2164 
2165 LogicalResult IfOp::inferReturnTypeComponents(
2166  MLIRContext *context, ::std::optional<Location> location,
2167  IfOp::Adaptor adaptor,
2168  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2169  llvm::SmallVector<tosa::YieldOp> yieldOps;
2170  for (Region *region : adaptor.getRegions()) {
2171  for (auto &block : *region)
2172  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2173  yieldOps.push_back(returnOp);
2174  }
2175 
2176  if (yieldOps.empty())
2177  return failure();
2178 
2179  // Get the initial type information for the yield op.
2180  llvm::SmallVector<ValueKnowledge> resultKnowledge;
2181  resultKnowledge.reserve(yieldOps.front().getNumOperands());
2182  for (auto operand : yieldOps.front().getOperands()) {
2183  resultKnowledge.push_back(
2184  ValueKnowledge::getKnowledgeFromType(operand.getType()));
2185  }
2186 
2187  for (auto yieldOp : yieldOps) {
2188  if (resultKnowledge.size() != yieldOp.getNumOperands())
2189  return failure();
2190 
2191  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2192  int32_t index = it.index();
2193  auto meet = ValueKnowledge::meet(
2194  resultKnowledge[index],
2195  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
2196  if (!meet)
2197  continue;
2198  resultKnowledge[index] = meet;
2199  }
2200  }
2201 
2202  for (const ValueKnowledge &result : resultKnowledge) {
2203  inferredReturnShapes.push_back(result.getShapedTypeComponents());
2204  }
2205 
2206  return success();
2207 }
2208 
2209 LogicalResult WhileOp::inferReturnTypeComponents(
2210  MLIRContext *context, ::std::optional<Location> location,
2211  WhileOp::Adaptor adaptor,
2212  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2213  llvm::SmallVector<tosa::YieldOp> yieldOps;
2214  for (auto &block : adaptor.getBody())
2215  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
2216  yieldOps.push_back(returnOp);
2217 
2218  // TOSA's while must have a tosa.yield as its terminator. If not found, this
2219  // tosa.while is invalid.
2220  if (yieldOps.empty())
2221  return failure();
2222 
2223  // Get the initial type information from the operand types.
2224  llvm::SmallVector<ValueKnowledge> resultKnowledge;
2225  resultKnowledge.reserve(yieldOps.front().getNumOperands());
2226  for (auto operand : yieldOps.front().getOperands()) {
2227  resultKnowledge.push_back(
2228  ValueKnowledge::getKnowledgeFromType(operand.getType()));
2229  }
2230 
2231  for (auto yieldOp : yieldOps) {
2232  if (resultKnowledge.size() != yieldOp.getNumOperands())
2233  return failure();
2234 
2235  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
2236  int32_t index = it.index();
2237  if (auto meet = ValueKnowledge::meet(
2238  resultKnowledge[index],
2239  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
2240  resultKnowledge[index] = meet;
2241  }
2242  }
2243  }
2244 
2245  for (const ValueKnowledge &result : resultKnowledge) {
2246  inferredReturnShapes.push_back(result.getShapedTypeComponents());
2247  }
2248 
2249  return success();
2250 }
2251 
2252 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
2253  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
2254  return llvm::to_vector<4>(vt.getShape());
2255  return std::nullopt;
2256 }
2257 
2258 // The parse and print of IfOp follow the implementation of the SCF dialect.
2259 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
2260  // Create the regions for 'then'.
2261  result.regions.reserve(2);
2262  Region *thenRegion = result.addRegion();
2263  Region *elseRegion = result.addRegion();
2264 
2265  auto &builder = parser.getBuilder();
2266  OpAsmParser::UnresolvedOperand cond;
2267  // Create an i1 tensor type for the boolean condition.
2268  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
2269  if (parser.parseOperand(cond) ||
2270  parser.resolveOperand(cond, i1Type, result.operands))
2271  return failure();
2272  // Parse optional results type list.
2273  if (parser.parseOptionalArrowTypeList(result.types))
2274  return failure();
2275  // Parse the 'then' region.
2276  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
2277  return failure();
2278 
2279  // If we find an 'else' keyword then parse the 'else' region.
2280  if (!parser.parseOptionalKeyword("else")) {
2281  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
2282  return failure();
2283  }
2284 
2285  // Parse the optional attribute list.
2286  if (parser.parseOptionalAttrDict(result.attributes))
2287  return failure();
2288  return success();
2289 }
2290 
2291 void IfOp::print(OpAsmPrinter &p) {
2292  bool printBlockTerminators = false;
2293 
2294  p << " " << getCond();
2295  if (!getResults().empty()) {
2296  p << " -> (" << getResultTypes() << ")";
2297  // Print yield explicitly if the op defines values.
2298  printBlockTerminators = true;
2299  }
2300  p << ' ';
2301  p.printRegion(getThenBranch(),
2302  /*printEntryBlockArgs=*/false,
2303  /*printBlockTerminators=*/printBlockTerminators);
2304 
2305  // Print the 'else' region if it exists and has a block.
2306  auto &elseRegion = getElseBranch();
2307  if (!elseRegion.empty()) {
2308  p << " else ";
2309  p.printRegion(elseRegion,
2310  /*printEntryBlockArgs=*/false,
2311  /*printBlockTerminators=*/printBlockTerminators);
2312  }
2313 
2314  p.printOptionalAttrDict((*this)->getAttrs());
2315 }
2316 
2317 LogicalResult ReverseOp::verify() {
2318  TensorType inputType = getInput1().getType();
2319  TensorType outputType = getOutput().getType();
2320  int32_t reverseAxis = getAxis();
2321 
2322  if (reverseAxis < 0)
2323  return emitOpError("expected non-negative reverse axis");
2324  if (inputType.hasRank()) {
2325  int64_t inputRank = inputType.getRank();
2326  // We allow for a special case where the input/output shape has rank 0 and
2327  // axis is also 0.
2328  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
2329  return emitOpError("expect input tensor rank (")
2330  << inputRank << ") to be larger than reverse axis (" << reverseAxis
2331  << ")";
2332  }
2333  if (outputType.hasRank()) {
2334  int64_t outputRank = outputType.getRank();
2335  if (inputType.hasRank() && outputRank != inputType.getRank())
2336  return emitOpError(
2337  "expect output tensor rank to be equal to input tensor rank");
2338  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
2339  return emitOpError("expect output tensor rank (")
2340  << outputRank << ") to be larger than reverse axis ("
2341  << reverseAxis << ")";
2342  }
2343  return success();
2344 }
2345 
2346 // The parse and print of WhileOp follow the implementation of the SCF dialect.
2347 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
2348  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2349  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2350  Region *cond = result.addRegion();
2351  Region *body = result.addRegion();
2352 
2353  OptionalParseResult listResult =
2354  parser.parseOptionalAssignmentList(regionArgs, operands);
2355  if (listResult.has_value() && failed(listResult.value()))
2356  return failure();
2357 
2358  FunctionType functionType;
2359  SMLoc typeLoc = parser.getCurrentLocation();
2360  if (failed(parser.parseColonType(functionType)))
2361  return failure();
2362 
2363  result.addTypes(functionType.getResults());
2364 
2365  if (functionType.getNumInputs() != operands.size()) {
2366  return parser.emitError(typeLoc)
2367  << "expected as many input types as operands "
2368  << "(expected " << operands.size() << " got "
2369  << functionType.getNumInputs() << ")";
2370  }
2371 
2372  // Resolve input operands.
2373  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
2374  parser.getCurrentLocation(),
2375  result.operands)))
2376  return failure();
2377 
2378  // Propagate the types into the region arguments.
2379  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
2380  regionArgs[i].type = functionType.getInput(i);
2381 
2382  return failure(parser.parseRegion(*cond, regionArgs) ||
2383  parser.parseKeyword("do") || parser.parseRegion(*body) ||
2384  parser.parseOptionalAttrDictWithKeyword(result.attributes));
2385 }
2386 
2387 static void printInitializationList(OpAsmPrinter &parser,
2388  Block::BlockArgListType blocksArgs,
2389  ValueRange initializers,
2390  StringRef prefix = "") {
2391  assert(blocksArgs.size() == initializers.size() &&
2392  "expected same length of arguments and initializers");
2393  if (initializers.empty())
2394  return;
2395 
2396  parser << prefix << '(';
2397  llvm::interleaveComma(
2398  llvm::zip(blocksArgs, initializers), parser,
2399  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
2400  parser << ")";
2401 }
2402 
2403 void WhileOp::print(OpAsmPrinter &parser) {
2404  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
2405  " ");
2406  parser << " : ";
2407  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
2408  parser << ' ';
2409  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
2410  parser << " do ";
2411  parser.printRegion(getBody());
2412  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2413 }
2414 
2415 LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
2416  Type zpElemType = zpAttr.getElementType();
2417  if (auto quantType =
2418  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
2419  zp = quantType.getZeroPoint();
2420  return success();
2421  }
2422  if (llvm::isa<FloatType>(zpElemType)) {
2423  // A non-zero zero point is not allowed for float types.
2424  if (!zpAttr.getValues<APFloat>()[0].isZero())
2425  return failure();
2426  zp = 0;
2427  return success();
2428  }
2429  if (llvm::isa<IntegerType>(zpElemType)) {
2430  zp = zpAttr.getValues<APInt>()[0].getSExtValue();
2431  return success();
2432  }
2433  // Zero point extraction is not supported for other element types.
2434  return failure();
2435 }
2436 
2437 // Create a rank-1 const tensor for zero point of the source tensor.
2438 std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
2439  Location loc,
2440  Type srcElemType,
2441  int64_t zp) {
2442  srcElemType = getElementTypeOrSelf(srcElemType);
2443  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
2444  srcElemType = quantType.getStorageType();
2445  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
2446  if (llvm::isa<FloatType>(srcElemType)) {
2447  auto zpAttr = DenseElementsAttr::get(
2448  zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
2449  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
2450  }
2451  if (llvm::isa<IntegerType>(srcElemType)) {
2452  auto zpAttr =
2453  DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
2454  return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
2455  }
2456  llvm::errs() << "zero point is not allowed for unsupported data types\n";
2457  return std::nullopt;
2458 }
2459 
2460 //===----------------------------------------------------------------------===//
2461 // TOSA Shape and Shape Operators Helper functions.
2462 //===----------------------------------------------------------------------===//
2463 
2464 bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
2465  return mlir::isa<tosa::shapeType>(t);
2466 }
2467 
2468 LogicalResult
2469 mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
2470  int rank) {
2471  if (rank < 0)
2472  return emitError() << "invalid rank (must be >= 0): " << rank;
2473  return success();
2474 }
2475 
2476 LogicalResult mlir::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
2477  for (auto v : op->getOperands()) {
2478  if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
2479  Operation *definingOp = v.getDefiningOp();
2480  if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
2481  return op->emitOpError("shape operand is not compile time resolvable");
2482  }
2483  }
2484  }
2485  return success();
2486 }
2487 
2488 LogicalResult mlir::tosa::verifyTosaShapeOperator(Operation *op) {
2489  for (auto type : op->getOperandTypes()) {
2490  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2491  return op->emitOpError("must have operands with tosa shape type");
2492  }
2493  }
2494  for (auto type : op->getResultTypes()) {
2495  if (!mlir::isa<mlir::tosa::shapeType>(type)) {
2496  return op->emitOpError("must have result with tosa shape type");
2497  }
2498  }
2499  return success();
2500 }
2501 
2502 LogicalResult
2503 mlir::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
2504  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
2505  failed(verifyTosaShapeOperator(op)))
2506  return failure();
2507 
2508  // Helper that returns the rank of a tosa shape type.
2509  auto getRank = [](const Type type) {
2510  return mlir::cast<mlir::tosa::shapeType>(type).getRank();
2511  };
2512  auto operandTypes = op->getOperandTypes();
2513  auto resultTypes = op->getResultTypes();
2514 
2515  auto rank = getRank(*op->getOperandTypes().begin());
2516  for (auto type : operandTypes) {
2517  if (getRank(type) != rank) {
2518  return op->emitOpError("operands don't have matching ranks");
2519  }
2520  }
2521  for (auto type : resultTypes) {
2522  if (getRank(type) != rank) {
2523  return op->emitOpError("result shape has different rank than operands");
2524  }
2525  }
2526  return success();
2527 }
2528 
2529 //===----------------------------------------------------------------------===//
2530 // TOSA Shape Operators verify functions.
2531 //===----------------------------------------------------------------------===//
2532 
2533 LogicalResult tosa::ConstShapeOp::verify() {
2534  // Check that the value attribute is one-dimensional (rank 1).
2535  auto valuesRank = getValue().getType().getRank();
2536  if (valuesRank != 1)
2537  return emitOpError("expect elements in attribute value with rank 1");
2538  // Check that the number of elements in the value attribute equals the rank of the result shape.
2539  auto count = getValue().getNumElements();
2540  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
2541  if (!(count == rank || (count == 1 && rank == 0))) {
2542  return emitOpError("expect number of elements in attribute value (")
2543  << count << ") to be equal to the rank (" << rank
2544  << ") for the result shape type";
2545  }
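  // Example: a result of type !tosa.shape<3> requires exactly 3 elements in
  // the value attribute; a rank-0 !tosa.shape<0> result may hold a single
  // element.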
2546  return success();
2547 }
2548 
2549 //===----------------------------------------------------------------------===//
2550 // TOSA Attribute Definitions.
2551 //===----------------------------------------------------------------------===//
2552 
2553 #define GET_ATTRDEF_CLASSES
2554 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
2555 
2556 //===----------------------------------------------------------------------===//
2557 // TOSA Type Definitions.
2558 //===----------------------------------------------------------------------===//
2559 #define GET_TYPEDEF_CLASSES
2560 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
2561 
2562 //===----------------------------------------------------------------------===//
2563 // TOSA Operator Definitions.
2564 //===----------------------------------------------------------------------===//
2565 
2566 #define GET_OP_CLASSES
2567 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"