MLIR  19.0.0git
TosaOps.cpp
Go to the documentation of this file.
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // \file
10 // This file implements the TOSA Specification:
11 // https://developer.mlplatform.org/w/tosa/
12 //
13 //===----------------------------------------------------------------------===//
14 
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 using namespace mlir;
34 using namespace mlir::tosa;
35 
36 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Tosa dialect interface includes.
40 //===----------------------------------------------------------------------===//
41 
42 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
43 
44 namespace {
45 #include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // Dialect Function Inliner Interface.
49 //===----------------------------------------------------------------------===//
50 struct TosaInlinerInterface : public DialectInlinerInterface {
52 
53  //===--------------------------------------------------------------------===//
54  // Analysis Hooks.
55  //===--------------------------------------------------------------------===//
56 
57  /// All operations can be inlined by default.
58  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
59  IRMapping &map) const final {
60  return true;
61  }
62 
63  /// All regions with If and While parent operators can be inlined.
64  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
65  IRMapping &map) const final {
66  return (isa<tosa::IfOp>(dest->getParentOp()) ||
67  isa<tosa::WhileOp>(dest->getParentOp()));
68  }
69 };
70 
71 /// This class implements the bytecode interface for the Tosa dialect.
72 struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
73  TosaDialectBytecodeInterface(Dialect *dialect)
74  : BytecodeDialectInterface(dialect) {}
75 
76  //===--------------------------------------------------------------------===//
77  // Attributes
78 
79  Attribute readAttribute(DialectBytecodeReader &reader) const override {
80  return ::readAttribute(getContext(), reader);
81  }
82 
83  LogicalResult writeAttribute(Attribute attr,
84  DialectBytecodeWriter &writer) const override {
85  return ::writeAttribute(attr, writer);
86  }
87 
88  //===--------------------------------------------------------------------===//
89  // Types
90 
91  Type readType(DialectBytecodeReader &reader) const override {
92  return ::readType(getContext(), reader);
93  }
94 
95  LogicalResult writeType(Type type,
96  DialectBytecodeWriter &writer) const override {
97  return ::writeType(type, writer);
98  }
99 
100  void writeVersion(DialectBytecodeWriter &writer) const final {
101  // TODO: Populate.
102  }
103 
104  std::unique_ptr<DialectVersion>
105  readVersion(DialectBytecodeReader &reader) const final {
106  // TODO: Populate
107  reader.emitError("Dialect does not support versioning");
108  return nullptr;
109  }
110 
111  LogicalResult upgradeFromVersion(Operation *topLevelOp,
112  const DialectVersion &version) const final {
113  return success();
114  }
115 };
116 
117 } // namespace
118 
119 //===----------------------------------------------------------------------===//
120 // TOSA control flow support.
121 //===----------------------------------------------------------------------===//
122 
123 /// Returns the while loop body.
124 SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
125 
126 //===----------------------------------------------------------------------===//
127 // Tosa dialect initialization.
128 //===----------------------------------------------------------------------===//
129 
130 void TosaDialect::initialize() {
131  addOperations<
132 #define GET_OP_LIST
133 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
134  >();
135  addAttributes<
136 #define GET_ATTRDEF_LIST
137 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
138  >();
139  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
140  declarePromisedInterfaces<
141  mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
142  ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp,
143  LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
144  LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
145  BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
146  NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
147  GreaterEqualOp, MatMulOp>();
148 }
149 
151  Type type, Location loc) {
152  // Tosa dialect constants only support ElementsAttr unlike standard dialect
153  // constant which supports all attributes.
154  if (llvm::isa<ElementsAttr>(value))
155  return builder.create<tosa::ConstOp>(loc, type,
156  llvm::cast<ElementsAttr>(value));
157  return nullptr;
158 }
159 
160 //===----------------------------------------------------------------------===//
161 // Parsers and printers
162 //===----------------------------------------------------------------------===//
163 
165  Attribute &attr) {
166  if (succeeded(parser.parseOptionalEqual())) {
167  if (failed(parser.parseAttribute(attr))) {
168  return parser.emitError(parser.getCurrentLocation())
169  << "expected attribute";
170  }
171  if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
172  typeAttr = TypeAttr::get(typedAttr.getType());
173  }
174  return success();
175  }
176 
177  Type type;
178  if (failed(parser.parseColonType(type))) {
179  return parser.emitError(parser.getCurrentLocation()) << "expected type";
180  }
181  typeAttr = TypeAttr::get(type);
182 
183  return success();
184 }
185 
187  Attribute attr) {
188  bool needsSpace = false;
189  auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
190  if (!typedAttr || typedAttr.getType() != type.getValue()) {
191  p << ": ";
192  p.printAttribute(type);
193  needsSpace = true; // subsequent attr value needs a space separator
194  }
195  if (attr) {
196  if (needsSpace)
197  p << ' ';
198  p << "= ";
199  p.printAttribute(attr);
200  }
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // TOSA Operator Verifiers.
205 //===----------------------------------------------------------------------===//
206 
207 static bool hasZeroDimension(ShapedType shapedType) {
208  if (!shapedType.hasRank())
209  return false;
210 
211  auto rank = shapedType.getRank();
212 
213  for (int i = 0; i < rank; i++) {
214  if (shapedType.isDynamicDim(i))
215  continue;
216  if (shapedType.getDimSize(i) == 0)
217  return true;
218  }
219 
220  return false;
221 }
222 
223 template <typename T> static LogicalResult verifyConvOp(T op) {
224  // All TOSA conv ops have an input() and weight().
225  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
226  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
227 
228  // Must be ranked tensor types
229  if (!inputType) {
230  op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
231  return failure();
232  }
233  if (!weightType) {
234  op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
235  return failure();
236  }
237 
238  if (hasZeroDimension(inputType))
239  return op.emitOpError() << "tensor has a dimension with size zero. Each "
240  "dimension of a tensor must have size >= 1";
241 
242  auto inputEType = inputType.getElementType();
243  auto weightEType = weightType.getElementType();
244 
245  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
246  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
247 
248  // Either both must be quantized or both unquantized.
249  if (inputIsQuant != weightIsQuant) {
250  op.emitOpError(
251  "expect both input and weight to be float or not together, got ")
252  << inputEType << " and " << weightEType;
253  return failure();
254  }
255 
256  // Quantized type must have constructed the quantizationattr, and unquantized
257  // types should not have a quantizationattr.
258  if ((inputIsQuant && !op.getQuantizationInfo()) ||
259  (!inputIsQuant && op.getQuantizationInfo())) {
260  op.emitOpError("quantizationattr is required for quantized type, and not "
261  "allowed for float type");
262  return failure();
263  }
264 
265  return success();
266 }
267 
269  // Ensure output is of 32-bit integer
270  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
271  if (!resultETy.isIntOrIndex())
272  return emitOpError("result tensor is not of integer type");
273 
274  // Ensure axis is within the tensor rank
275  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
276  const int64_t axis = getAxisAttr().getInt();
277  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
278  return emitOpError("specified axis is outside the rank of the tensor");
279 
280  return success();
281 }
282 
284  auto inputType = llvm::cast<ShapedType>(getInput().getType());
285  if (hasZeroDimension(inputType))
286  return emitOpError() << "tensor has a dimension with size zero. Each "
287  "dimension of a tensor must have size >= 1";
288 
289  auto inputETy = inputType.getElementType();
290  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
291 
292  if (auto quantType =
293  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
294  inputETy = quantType.getStorageType();
295 
296  if (auto quantType =
297  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
298  resultETy = quantType.getStorageType();
299 
300  auto accType = getAccType();
301  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
302  return emitOpError("accumulator type for integer tensor is not i32");
303 
304  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
305  return emitOpError("accumulator type for f16 tensor is not f16/f32");
306 
307  if (inputETy.isBF16() && !accType.isF32())
308  return emitOpError("accumulator type for bf16 tensor is not f32");
309 
310  if (inputETy.isF32() && !accType.isF32())
311  return emitOpError("accumulator type for f32 tensor is not f32");
312 
313  if ((inputETy.isF32() && resultETy.isF32()) ||
314  (inputETy.isF16() && resultETy.isF16()) ||
315  (inputETy.isBF16() && resultETy.isBF16()) ||
316  (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
317  (inputETy.isInteger(16) && resultETy.isInteger(16)))
318  return success();
319 
320  return emitOpError("input/output element types are incompatible.");
321 }
322 
324  mlir::Type inputETy =
325  llvm::cast<ShapedType>(getInput().getType()).getElementType();
326  if (auto quantType =
327  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
328  inputETy = quantType.getStorageType();
329  }
330  mlir::Type maxFpType = getMaxFpAttr().getType();
331  mlir::Type minFpType = getMinFpAttr().getType();
332  mlir::Type outputETy =
333  llvm::cast<ShapedType>(getOutput().getType()).getElementType();
334  if (auto quantType =
335  llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
336  outputETy = quantType.getStorageType();
337  }
338  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();
339 
340  if (inputETy != outputETy)
341  return emitOpError("input/output element types are incompatible.");
342 
343  // if input datatype is float, check that the two min/max_fp attributes share
344  // the same type and that their type is either the same of the input's
345  // datatype, or a float type whose bitwidth > input datatype bitwidth
346  if (!inputETy.isInteger(dataTypeBitWidth)) {
347  if (((maxFpType != minFpType) ||
348  (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
349  inputETy.getIntOrFloatBitWidth())))
350  return emitOpError("min/max attributes types are incompatible with "
351  "input/output element types.");
352  }
353 
354  return success();
355 }
356 
357 //===----------------------------------------------------------------------===//
358 // TOSA Operator Quantization Builders.
359 //===----------------------------------------------------------------------===//
360 
361 /// This builder is called on all convolution operators except TransposeConv,
362 /// which has specialized output shape semantics. The builder also defines the
363 /// bitwidth of the output given the bit width of the input & weight content.
364 static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
365  Type outputType, Value input, Value weight,
366  Value bias, DenseI64ArrayAttr pad,
367  DenseI64ArrayAttr stride,
368  DenseI64ArrayAttr dilation) {
369 
370  result.addOperands({input, weight, bias});
371  result.addAttribute("pad", pad);
372  result.addAttribute("stride", stride);
373  result.addAttribute("dilation", dilation);
374 
375  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
376  if (quantAttr) {
377  result.addAttribute("quantization_info", quantAttr);
378  result.addTypes(
379  buildConvOpResultTypeInfo(builder, outputType, input, weight));
380  } else {
381  result.addTypes(outputType);
382  }
383 }
384 
385 /// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
387  OpBuilder &builder, OperationState &result, Type outputType, Value input,
388  Value weight, Value bias, DenseI64ArrayAttr outpad,
389  DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
390  result.addOperands({input, weight, bias});
391  result.addAttribute("out_pad", outpad);
392  result.addAttribute("stride", stride);
393  result.addAttribute("out_shape", outputShape);
394  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
395 
396  if (quantAttr) {
397  result.addAttribute("quantization_info", quantAttr);
398  result.addTypes(
399  buildConvOpResultTypeInfo(builder, outputType, input, weight));
400  } else {
401  result.addTypes(outputType);
402  }
403 }
404 
405 /// The tosa.fully_connected op has its own builder as it does not have
406 /// strides/dilation/padding.
407 static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
408  Type outputType, Value input, Value weight,
409  Value bias) {
410 
411  result.addOperands({input, weight, bias});
412  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
413  if (quantAttr) {
414  result.addAttribute("quantization_info", quantAttr);
415  result.addTypes(
416  buildConvOpResultTypeInfo(builder, outputType, input, weight));
417  } else {
418  result.addTypes(outputType);
419  }
420 }
421 
422 /// The tosa.matmul op is also intended to be generated where a fully_connected
423 /// op must be constructed where the weight is not a constant. In this case,
424 /// the fully_connected op must be expressed using matmul.
425 /// TODO: Add link to the leglization document explaining this.
427  OperationState &result, Type outputType,
428  Value a, Value b) {
429  result.addOperands({a, b});
430  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
431 
432  if (quantAttr) {
433  result.addAttribute("quantization_info", quantAttr);
434 
435  auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
436  assert(inputType && "Input must be a shaped tensor type!");
437 
438  auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
439  inputType.getElementType());
440  assert(inputQType && "Tensor must have quantized datatype!");
441 
442  unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
443 
444  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
445  assert(outputShapedType && "Output must be a shaped type");
446 
447  IntegerType accElementType;
448  if (inputBits == 16)
449  accElementType = builder.getIntegerType(48);
450  else
451  accElementType = builder.getI32Type();
452  auto accType = outputShapedType.clone(accElementType);
453  result.addTypes(accType);
454  } else {
455  result.addTypes(outputType);
456  }
457 }
458 
459 /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
460 /// but avg_pool operator has its own builder as it has additional parameters
461 /// not part of the unary ops.
462 static void
464  Type outputType, Value input,
465  DenseArrayAttr kernel, DenseArrayAttr stride,
466  DenseArrayAttr pad, TypeAttr accType) {
467  result.addOperands(input);
468  result.addAttribute("kernel", kernel);
469  result.addAttribute("stride", stride);
470  result.addAttribute("pad", pad);
471  result.addAttribute("acc_type", accType);
472  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
473  if (quantAttr)
474  result.addAttribute("quantization_info", quantAttr);
475  result.types.push_back(outputType);
476 }
477 
478 /// This builder is called on single-parameter unary operators that have scale
479 /// relationship between their input and output, expressed by the
480 /// UnaryOpQuantizationAttr.
482  OperationState &result, Type outputType,
483  Value input) {
484  result.addOperands(input);
485  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
486  if (quantAttr)
487  result.addAttribute("quantization_info", quantAttr);
488  result.types.push_back(outputType);
489 }
490 
491 /// This builder is called on TOSA pad operator that needs to create its own
492 /// OptionalAttr quantization_attr parameter to scale the padding values
493 /// correctly. No pad_const is interpreted as zero-padding.
494 static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
495  Type outputType, Value input,
496  Value paddings) {
497  result.addOperands({input, paddings});
498  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
499  if (quantAttr)
500  result.addAttribute("quantization_info", quantAttr);
501  result.types.push_back(outputType);
502 }
503 
504 /// This builder is called on TOSA pad operator when an explicit pad_const
505 /// value is passed in. It also optionally constructs quantization_attr.
507  OperationState &result,
508  Type outputType, Value input,
509  Value paddings,
510  Value padConst) {
511  result.addOperands({input, paddings, padConst});
512  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
513  if (quantAttr)
514  result.addAttribute("quantization_info", quantAttr);
515  result.types.push_back(outputType);
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // TOSA Operator Return Type Inference.
520 //===----------------------------------------------------------------------===//
521 
523  SmallVector<int64_t> &outShape) {
524  int64_t outRank = 0;
525  for (int i = 0, e = operands.size(); i != e; ++i) {
526  auto shape = operands.getShape(i);
527  if (!shape.hasRank()) {
528  // TODO(jennik): Update function to have better case handling for invalid
529  // operands and for ranked tensors.
530  return failure();
531  }
532  outRank = std::max<int64_t>(outRank, shape.getRank());
533  }
534 
535  outShape.resize(outRank, 1);
536 
537  for (int i = 0, e = operands.size(); i != e; ++i) {
538  auto shape = operands.getShape(i);
539  auto rankDiff = outShape.size() - shape.getRank();
540 
541  for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
542  auto dim1 = outShape[i + rankDiff];
543  auto dim2 = shape.getDimSize(i);
544  auto resolvedDim = dim1;
545 
546  if (dim1 == 1) {
547  resolvedDim = dim2;
548  } else if (dim2 == 1) {
549  resolvedDim = dim1;
550  } else if (dim1 != dim2) {
551  return failure();
552  }
553  outShape[i + rankDiff] = resolvedDim;
554  }
555  }
556 
557  return success();
558 }
559 
560 LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
561  MLIRContext *context, ::std::optional<Location> location,
562  ArgMaxOp::Adaptor adaptor,
563  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
564  ShapeAdaptor inputShape(adaptor.getInput().getType());
565  IntegerAttr axis = adaptor.getProperties().axis;
566  int32_t axisVal = axis.getValue().getSExtValue();
567 
568  if (!inputShape.hasRank()) {
569  inferredReturnShapes.push_back(ShapedTypeComponents());
570  return success();
571  }
572 
573  SmallVector<int64_t> outShape;
574  outShape.reserve(inputShape.getRank() - 1);
575  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
576  if (i == axisVal)
577  continue;
578  outShape.push_back(inputShape.getDimSize(i));
579  }
580 
581  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
582  return success();
583 }
584 
585 LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
586  MLIRContext *context, ::std::optional<Location> location,
587  RFFT2dOp::Adaptor adaptor,
588  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
589  ShapeAdaptor inputShape(adaptor.getInput().getType());
590 
591  if (!inputShape.hasRank())
592  return failure();
593 
594  llvm::SmallVector<int64_t> outputShape;
595  outputShape.resize(3, ShapedType::kDynamic);
596  outputShape[0] = inputShape.getDimSize(0);
597  outputShape[1] = inputShape.getDimSize(1);
598  int64_t inWidth = inputShape.getDimSize(2);
599 
600  // Note that we can support this calculation symbolically
601  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
602  if (inWidth != ShapedType::kDynamic)
603  outputShape[2] = inWidth / 2 + 1;
604 
605  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
606  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
607 
608  return success();
609 }
610 
611 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
612  MLIRContext *context, ::std::optional<Location> location,
613  FFT2dOp::Adaptor adaptor,
614  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
615  inferredReturnShapes.push_back(
616  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
617  inferredReturnShapes.push_back(
618  ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
619  return success();
620 }
621 
622 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
623  MLIRContext *context, ::std::optional<Location> location,
624  ConcatOp::Adaptor adaptor,
625  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
626  // Infer all dimension sizes by reducing based on inputs.
627  const Properties &prop = adaptor.getProperties();
628  int32_t axis = prop.axis.getValue().getSExtValue();
629  llvm::SmallVector<int64_t> outputShape;
630  bool hasRankedInput = false;
631  for (auto operand : adaptor.getOperands()) {
632  ShapeAdaptor operandShape(operand.getType());
633  if (!operandShape.hasRank())
634  continue;
635 
636  // Copy the Operand's rank.
637  if (!hasRankedInput)
638  outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);
639 
640  // Copy shapes until the dim is non-dynamic.
641  for (int i = 0, s = operandShape.getRank(); i < s; i++) {
642  if (i == axis || operandShape.isDynamicDim(i))
643  continue;
644  if (outputShape[i] == ShapedType::kDynamic)
645  outputShape[i] = operandShape.getDimSize(i);
646  if (outputShape[i] != operandShape.getDimSize(i))
647  return emitOptionalError(location,
648  "Cannot concat tensors with different sizes"
649  " on the non-axis dimension ",
650  i);
651  }
652 
653  hasRankedInput = true;
654  }
655  Type inputType =
656  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
657  if (!hasRankedInput) {
658  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
659  return success();
660  }
661 
662  // Determine the dimension size along the concatenation axis.
663  int64_t concatDimSize = 0;
664  for (auto operand : adaptor.getOperands()) {
665  ShapeAdaptor operandShape(operand.getType());
666 
667  // We need to know the length of the concatenation axis of all inputs to
668  // determine the dimension size of the output shape.
669  if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
670  concatDimSize = ShapedType::kDynamic;
671  break;
672  }
673 
674  concatDimSize += operandShape.getDimSize(axis);
675  }
676 
677  outputShape[axis] = concatDimSize;
678 
679  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
680  return success();
681 }
682 
683 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
684  MLIRContext *context, ::std::optional<Location> location,
685  ValueShapeRange operands, DictionaryAttr attributes,
686  OpaqueProperties properties, RegionRange regions,
687  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
688  auto elementType = IntegerType::get(context, /*width=*/1);
689 
691  if (resolveBroadcastShape(operands, outShape).failed()) {
692  inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
693  return success();
694  }
695 
696  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
697  return success();
698 }
699 
700 bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
701  if (l.size() != r.size() || l.size() != 1)
702  return false;
703  return succeeded(verifyCompatibleShape(l[0], r[0]));
704 }
705 
706 LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
707  MLIRContext *context, ::std::optional<Location> location,
708  FullyConnectedOp::Adaptor adaptor,
709  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
710  ShapeAdaptor inputShape(adaptor.getInput().getType());
711  ShapeAdaptor weightShape(adaptor.getWeight().getType());
712  ShapeAdaptor biasShape(adaptor.getBias().getType());
713 
714  // All shapes are dynamic.
715  SmallVector<int64_t> outShape;
716  outShape.resize(2, ShapedType::kDynamic);
717 
718  if (inputShape.hasRank()) {
719  outShape[0] = inputShape.getDimSize(0);
720  }
721 
722  if (weightShape.hasRank()) {
723  outShape[1] = weightShape.getDimSize(0);
724  }
725 
726  if (biasShape.hasRank()) {
727  outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0)
728  : outShape[1];
729  }
730 
731  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
732  return success();
733 }
734 
736 
737 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
738  MLIRContext *context, ::std::optional<Location> location,
739  MatMulOp::Adaptor adaptor,
740  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
741  ShapeAdaptor lhsShape(adaptor.getA().getType());
742  ShapeAdaptor rhsShape(adaptor.getB().getType());
743 
744  // All shapes are dynamic.
745  SmallVector<int64_t> outShape;
746  outShape.resize(3, ShapedType::kDynamic);
747 
748  if (lhsShape.hasRank()) {
749  outShape[0] = lhsShape.getDimSize(0);
750  outShape[1] = lhsShape.getDimSize(1);
751  }
752 
753  if (rhsShape.hasRank()) {
754  outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
755  : outShape[0];
756  outShape[2] = rhsShape.getDimSize(2);
757  }
758 
759  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
760  return success();
761 }
762 
763 LogicalResult tosa::PadOp::inferReturnTypeComponents(
764  MLIRContext *context, ::std::optional<Location> location,
765  PadOp::Adaptor adaptor,
766  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
767  ShapeAdaptor inputShape(adaptor.getInput1().getType());
768  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
769  SmallVector<int64_t> outputShape;
770 
771  // If both inputs have unknown shape, we cannot determine the shape of the
772  // output.
773  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
774  inferredReturnShapes.push_back(ShapedTypeComponents());
775  return success();
776  }
777 
778  // If the input rank is unknown we can info the output rank using the padding
779  // shape's first dim.
780  if (!inputShape.hasRank()) {
781  if (paddingShape.isDynamicDim(0)) {
782  inferredReturnShapes.push_back(ShapedTypeComponents());
783  return success();
784  }
785 
786  outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
787  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
788  return success();
789  }
790 
791  DenseIntElementsAttr paddings;
792  // If the paddings value is not a constant, all dimensions must be dynamic.
793  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
794  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
795  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
796  return success();
797  }
798 
799  SmallVector<int64_t> paddingValues;
800  for (auto val : paddings) {
801  paddingValues.push_back(val.getSExtValue());
802  }
803 
804  outputShape.reserve(inputShape.getRank());
805  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
806  if (inputShape.isDynamicDim(i)) {
807  outputShape.push_back(ShapedType::kDynamic);
808  continue;
809  }
810 
811  outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
812  paddingValues[i * 2 + 1]);
813  }
814 
815  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
816  return success();
817 }
818 
820  return to_vector(llvm::map_range(shape, [](int64_t dim) {
821  return dim == -1 ? ShapedType::kDynamic : dim;
822  }));
823 }
824 
825 LogicalResult tosa::SliceOp::inferReturnTypeComponents(
826  MLIRContext *context, ::std::optional<Location> location,
827  SliceOp::Adaptor adaptor,
828  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
829  inferredReturnShapes.push_back(
830  ShapedTypeComponents(convertToMlirShape(adaptor.getSize())));
831  return success();
832 }
833 
835  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
836  if (!inputType)
837  return success();
838 
839  if (static_cast<size_t>(inputType.getRank()) != getStart().size())
840  return emitOpError(
841  "length of start attribute is not equal rank of input shape");
842 
843  if (static_cast<size_t>(inputType.getRank()) != getSize().size())
844  return emitOpError(
845  "length of size attribute is not equal rank of input shape");
846 
847  return success();
848 }
849 
850 LogicalResult tosa::TableOp::inferReturnTypeComponents(
851  MLIRContext *context, ::std::optional<Location> location,
852  TableOp::Adaptor adaptor,
853  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
854  ShapeAdaptor inputShape(adaptor.getInput().getType());
855 
856  if (!inputShape.hasRank()) {
857  inferredReturnShapes.push_back(ShapedTypeComponents());
858  return success();
859  }
860 
861  inferredReturnShapes.resize(1);
862  inputShape.getDims(inferredReturnShapes[0]);
863  return success();
864 }
865 
866 LogicalResult tosa::TileOp::inferReturnTypeComponents(
867  MLIRContext *context, ::std::optional<Location> location,
868  TileOp::Adaptor adaptor,
869  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
870  ArrayRef<int64_t> multiples = adaptor.getMultiples();
871  ShapeAdaptor inputShape(adaptor.getInput1().getType());
872  SmallVector<int64_t> outputShape;
873  if (!inputShape.hasRank()) {
874  outputShape.resize(multiples.size(), ShapedType::kDynamic);
875  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
876  return success();
877  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
878  return failure();
879 
880  // Any non dynamic dimension can be multiplied to a known size.
881  outputShape.reserve(multiples.size());
882  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
883  int64_t dim = inputShape.getDimSize(i);
884  if (dim != ShapedType::kDynamic)
885  dim *= multiples[i];
886  outputShape.push_back(dim);
887  }
888 
889  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
890  return success();
891 }
892 
894  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
895  ShapedType outputType = llvm::cast<ShapedType>(getType());
896  auto multiples = getMultiples();
897 
898  if (inputType.hasRank()) {
899  if (static_cast<size_t>(inputType.getRank()) != multiples.size())
900  return emitOpError("expect 'multiples' array to have length ")
901  << inputType.getRank() << " but got " << multiples.size() << ".";
902  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
903  return emitOpError("expect same input and output tensor rank.");
904  } else if (outputType.hasRank() &&
905  static_cast<size_t>(outputType.getRank()) != multiples.size())
906  return emitOpError("expect 'multiples' array to have length ")
907  << outputType.getRank() << " but got " << multiples.size() << ".";
908 
909  return success();
910 }
911 
912 bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
913  if (l.size() != r.size() || l.size() != 1)
914  return false;
915  return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
916 }
917 
918 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
919  MLIRContext *context, ::std::optional<Location> location,
920  ReshapeOp::Adaptor adaptor,
921  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
922  ShapeAdaptor inputShape(adaptor.getInput1().getType());
923  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
924  llvm::SmallVector<int64_t> newShapeValue =
925  convertToMlirShape(adaptor.getNewShape());
926 
927  // We cannot infer from the total number of elements so we must take the
928  // shape attribute as exact.
929  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
930  inferredReturnShapes.push_back(
931  ShapedTypeComponents(newShapeValue, inputType));
932  return success();
933  }
934 
935  // Determine the number of elements covered by the slice of all static
936  // dimensions. This allows us to infer the length of the remaining dynamic
937  // dimension.
938  int64_t numElements = inputShape.getNumElements();
939  int64_t staticMul = 1;
940  for (auto val : newShapeValue) {
941  if (!ShapedType::isDynamic(val)) {
942  staticMul *= val;
943  }
944  }
945 
946  // Determine the length of the dynamic dimension.
947  for (auto &val : newShapeValue) {
948  if (ShapedType::isDynamic(val))
949  val = numElements / staticMul;
950  }
951 
952  inferredReturnShapes.push_back(
953  ShapedTypeComponents(newShapeValue, inputType));
954  return success();
955 }
956 
958  TensorType inputType = getInput1().getType();
959  RankedTensorType outputType = getType();
960 
961  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
962  return emitOpError() << "tensor has a dimension with size zero. Each "
963  "dimension of a tensor must have size >= 1";
964 
965  if ((int64_t) getNewShape().size() != outputType.getRank())
966  return emitOpError() << "new shape does not match result rank";
967 
968  for (auto [newShapeDim, outputShapeDim] :
969  zip(getNewShape(), outputType.getShape()))
970  if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
971  newShapeDim != outputShapeDim)
972  return emitOpError() << "new shape is inconsistent with result shape";
973 
974  if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
975  int64_t inputElementsNum = inputType.getNumElements();
976  int64_t outputElementsNum = outputType.getNumElements();
977  if (inputElementsNum != outputElementsNum) {
978  return emitOpError() << "cannot reshape " << inputElementsNum
979  << " elements into " << outputElementsNum;
980  }
981  }
982 
983  int missingDims = llvm::count(getNewShape(), -1);
984  if (missingDims > 1)
985  return emitOpError() << "expected at most one target dimension to be -1";
986 
987  return mlir::success();
988 }
989 
990 LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
991  // Perms must be constants.
992  DenseIntElementsAttr permsAttr;
993  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
994  return failure();
995 
996  // Transpose is not the identity transpose.
997  perms = llvm::to_vector(
998  llvm::map_range(permsAttr.getValues<APInt>(),
999  [](const APInt &val) { return val.getSExtValue(); }));
1000 
1001  return success();
1002 }
1003 
1004 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1005  MLIRContext *context, ::std::optional<Location> location,
1006  TransposeOp::Adaptor adaptor,
1007  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1008  ShapeAdaptor inputShape(adaptor.getInput1().getType());
1009  ShapeAdaptor permsShape(adaptor.getPerms().getType());
1010 
1011  // We cannot infer anything from a rank-0 "permutation" tensor.
1012  if (permsShape.hasRank() && permsShape.getRank() == 0)
1013  return failure();
1014 
1015  // If input rank and permutation length is unknown, the output rank is
1016  // unknown.
1017  if (!inputShape.hasRank() || !permsShape.hasRank() ||
1018  permsShape.isDynamicDim(0)) {
1019  inferredReturnShapes.push_back(ShapedTypeComponents());
1020  return success();
1021  }
1022 
1023  // This would imply the number of permutations does not match the rank of the
1024  // input which is illegal.
1025  if (permsShape.getDimSize(0) != inputShape.getRank()) {
1026  return failure();
1027  }
1028 
1029  SmallVector<int64_t> outputShape;
1030  // Rank-0 means no permutations matter.
1031  if (inputShape.getRank() == 0) {
1032  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1033  return success();
1034  }
1035 
1036  // Check whether the input dimensions are all the same.
1037  bool allTheSame = true;
1038  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
1039  if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
1040  allTheSame = false;
1041  break;
1042  }
1043  }
1044 
1045  // If all of the input dimensions are the same we don't care about the
1046  // permutation.
1047  if (allTheSame) {
1048  outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
1049  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1050  return success();
1051  }
1052 
1053  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
1054  // If the permuations are a constant we can directly determine the output
1055  // shape.
1056  DenseIntElementsAttr attr;
1057  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
1058  attr.getType().getRank() == 1) {
1059  ShapeAdaptor permShape = attr;
1060  // Constant permutation must be the same length as the input rank.
1061  if (inputShape.getRank() != permShape.getRank())
1062  return emitOptionalError(location,
1063  "constant permutation must be the same length"
1064  " as the input rank");
1065 
1066  // Constant permutation values must be within the input rank.
1067  for (int i = 0, e = inputShape.getRank(); i < e; i++) {
1068  if (inputShape.getRank() <= permShape.getDimSize(i))
1069  return failure();
1070  }
1071 
1072  outputShape.reserve(inputShape.getRank());
1073  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1074  outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
1075  }
1076  }
1077 
1078  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1079  return success();
1080 }
1081 
1083  TensorType inputType = getInput1().getType();
1084  TensorType permType = getPerms().getType();
1085  TensorType outputType = getOutput().getType();
1086 
1087  if (permType.hasRank() && permType.getRank() != 1)
1088  return emitOpError()
1089  << "expected permutation tensor to be rank 1 but got rank "
1090  << permType.getRank();
1091  if (inputType.hasRank() && permType.hasRank())
1092  if (!permType.isDynamicDim(0) &&
1093  permType.getDimSize(0) != inputType.getRank())
1094  return emitOpError() << "expected permutation tensor dim 0 to have size "
1095  << inputType.getRank()
1096  << " (input rank) but got size "
1097  << permType.getDimSize(0);
1098  if (inputType.hasRank() && outputType.hasRank() &&
1099  inputType.getRank() != outputType.getRank())
1100  return emitOpError()
1101  << "expected input tensor rank to equal result tensor rank";
1102  if (outputType.hasRank() && permType.hasRank())
1103  if (!permType.isDynamicDim(0) &&
1104  permType.getDimSize(0) != outputType.getRank())
1105  return emitOpError() << "expected permutation tensor dim 0 to have size "
1106  << outputType.getRank()
1107  << " (output rank) but got size "
1108  << permType.getDimSize(0);
1109 
1110  SmallVector<int64_t> constantPerms;
1111  if (succeeded(getConstantPerms(constantPerms))) {
1112  // Assert that the permutation tensor has a rank, which means that the rank
1113  // has been verified above.
1114  assert(permType.hasRank() &&
1115  "Unexpectedly found permutation tensor without rank");
1116  if (!isPermutationVector(constantPerms))
1117  return emitOpError() << "expected valid permutation tensor";
1118  }
1119  return success();
1120 }
1121 
1122 LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1123  MLIRContext *context, ::std::optional<Location> location,
1124  GatherOp::Adaptor adaptor,
1125  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1126  llvm::SmallVector<int64_t> outputShape;
1127  outputShape.resize(3, ShapedType::kDynamic);
1128 
1129  ShapeAdaptor valuesShape(adaptor.getValues().getType());
1130  if (valuesShape.hasRank()) {
1131  outputShape[0] = valuesShape.getDimSize(0);
1132  outputShape[2] = valuesShape.getDimSize(2);
1133  }
1134 
1135  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1136  if (indicesShape.hasRank()) {
1137  if (outputShape[0] == ShapedType::kDynamic)
1138  outputShape[0] = indicesShape.getDimSize(0);
1139  if (outputShape[1] == ShapedType::kDynamic)
1140  outputShape[1] = indicesShape.getDimSize(1);
1141  }
1142 
1143  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1144  return success();
1145 }
1146 
1147 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
1148  MLIRContext *context, ::std::optional<Location> location,
1149  ResizeOp::Adaptor adaptor,
1150  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1151  llvm::SmallVector<int64_t, 4> outputShape;
1152  outputShape.resize(4, ShapedType::kDynamic);
1153 
1154  ShapeAdaptor inputShape(adaptor.getInput().getType());
1155  if (!inputShape.hasRank())
1156  return failure();
1157 
1158  outputShape[0] = inputShape.getDimSize(0);
1159  outputShape[3] = inputShape.getDimSize(3);
1160  int64_t inputHeight = inputShape.getDimSize(1);
1161  int64_t inputWidth = inputShape.getDimSize(2);
1162 
1163  if ((inputHeight == ShapedType::kDynamic) ||
1164  (inputWidth == ShapedType::kDynamic))
1165  return failure();
1166 
1167  llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1168  llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1169  llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1170 
1171  // Compute the output shape based on attributes: scale, offset, and border.
1172  outputShape[1] =
1173  (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1174  scaleInt[1]) +
1175  1;
1176 
1177  outputShape[2] =
1178  (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1179  scaleInt[3]) +
1180  1;
1181 
1182  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1183  return success();
1184 }
1185 
1186 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1187  MLIRContext *context, ::std::optional<Location> location,
1188  ScatterOp::Adaptor adaptor,
1189  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1190  llvm::SmallVector<int64_t> outputShape;
1191  outputShape.resize(3, ShapedType::kDynamic);
1192 
1193  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
1194  if (valuesInShape.hasRank()) {
1195  outputShape[0] = valuesInShape.getDimSize(0);
1196  outputShape[1] = valuesInShape.getDimSize(1);
1197  outputShape[2] = valuesInShape.getDimSize(2);
1198  }
1199 
1200  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
1201  if (indicesShape.hasRank()) {
1202  if (outputShape[0] == ShapedType::kDynamic)
1203  outputShape[0] = indicesShape.getDimSize(0);
1204  }
1205 
1206  ShapeAdaptor inputShape(adaptor.getInput().getType());
1207  if (inputShape.hasRank()) {
1208  if (outputShape[0] == ShapedType::kDynamic)
1209  outputShape[0] = inputShape.getDimSize(0);
1210  if (outputShape[2] == ShapedType::kDynamic)
1211  outputShape[2] = inputShape.getDimSize(2);
1212  }
1213 
1214  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1215  return success();
1216 }
1217 
1219  ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1220  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1221  int64_t axisVal = axis.getValue().getSExtValue();
1222  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
1223  inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
1224  return success();
1225  }
1226 
1227  SmallVector<int64_t> outputShape;
1228  operandShape.getDims(outputShape);
1229  outputShape[axisVal] = 1;
1230  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
1231  return success();
1232 }
1233 
1234 #define COMPATIBLE_RETURN_TYPES(OP) \
1235  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
1236  if (l.size() != r.size() || l.size() != 1) \
1237  return false; \
1238  if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
1239  return false; \
1240  return succeeded(verifyCompatibleShape(l[0], r[0])); \
1241  }
1242 
1243 #define REDUCE_SHAPE_INFER(OP) \
1244  LogicalResult OP::inferReturnTypeComponents( \
1245  MLIRContext *context, ::std::optional<Location> location, \
1246  OP::Adaptor adaptor, \
1247  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1248  Type inputType = \
1249  llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
1250  ShapeAdaptor inputShape(adaptor.getInput().getType()); \
1251  const Properties &prop = adaptor.getProperties(); \
1252  return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
1253  inferredReturnShapes); \
1254  } \
1255  COMPATIBLE_RETURN_TYPES(OP)
1256 
1257 REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
1258 REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
1259 REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
1260 REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
1261 REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
1262 REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
1263 #undef REDUCE_SHAPE_INFER
1264 COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
1265 #undef COMPATIBLE_RETURN_TYPES
1266 
1267 template <typename T>
1269  // All TOSA reduce Ops have input, output and axis.
1270  TensorType inputType = op.getInput().getType();
1271  TensorType outputType = op.getOutput().getType();
1272  int32_t reduceAxis = op.getAxis();
1273 
1274  if (reduceAxis < 0) {
1275  op.emitOpError("reduce axis must not be negative");
1276  return failure();
1277  }
1278  if (inputType.hasRank()) {
1279  int64_t inputRank = inputType.getRank();
1280  // We allow for a special case where the input/output shape has rank 0 and
1281  // axis is also 0.
1282  if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
1283  op.emitOpError("expect input tensor rank (")
1284  << inputRank << ") to be larger than reduce axis (" << reduceAxis
1285  << ")";
1286  return failure();
1287  }
1288  }
1289  if (outputType.hasRank()) {
1290  int64_t outputRank = outputType.getRank();
1291  if (inputType.hasRank() && outputRank != inputType.getRank()) {
1292  op.emitOpError(
1293  "expect output tensor rank to be equal to input tensor rank");
1294  return failure();
1295  }
1296  if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
1297  op.emitOpError("expect output tensor rank (")
1298  << outputRank << ") to be larger than reduce axis (" << reduceAxis
1299  << ")";
1300  return failure();
1301  }
1302  // We can only verify the reduced dimension size to be 1 if this is not the
1303  // special case of output rank == 0.
1304  if (outputRank != 0) {
1305  auto outputShape = outputType.getShape();
1306  if (!outputType.isDynamicDim(reduceAxis) &&
1307  outputShape[reduceAxis] != 1) {
1308  op.emitOpError("expect reduced dimension size to be 1, got ")
1309  << outputShape[reduceAxis];
1310  return failure();
1311  }
1312  }
1313  }
1314  return success();
1315 }
1316 
1323 
1325  const ValueShapeRange &operands,
1326  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1327  llvm::SmallVector<int64_t> outShape;
1328  if (resolveBroadcastShape(operands, outShape).failed()) {
1329  inferredReturnShapes.push_back(ShapedTypeComponents());
1330  } else {
1331  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1332  }
1333  return success();
1334 }
1335 
1336 #define NARY_SHAPE_INFER(OP) \
1337  LogicalResult OP::inferReturnTypeComponents( \
1338  MLIRContext *context, ::std::optional<Location> location, \
1339  ValueShapeRange operands, DictionaryAttr attributes, \
1340  OpaqueProperties properties, RegionRange regions, \
1341  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
1342  return NAryInferReturnTypes(operands, inferredReturnShapes); \
1343  }
1344 
1345 NARY_SHAPE_INFER(tosa::AbsOp)
1346 NARY_SHAPE_INFER(tosa::AddOp)
1347 NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
1348 NARY_SHAPE_INFER(tosa::BitwiseAndOp)
1349 NARY_SHAPE_INFER(tosa::BitwiseOrOp)
1350 NARY_SHAPE_INFER(tosa::BitwiseXorOp)
1351 NARY_SHAPE_INFER(tosa::BitwiseNotOp)
1352 NARY_SHAPE_INFER(tosa::CastOp)
1353 NARY_SHAPE_INFER(tosa::CeilOp)
1354 NARY_SHAPE_INFER(tosa::ClampOp)
1355 NARY_SHAPE_INFER(tosa::ClzOp)
1356 NARY_SHAPE_INFER(tosa::CosOp)
1357 NARY_SHAPE_INFER(tosa::DivOp)
1358 NARY_SHAPE_INFER(tosa::ExpOp)
1359 NARY_SHAPE_INFER(tosa::FloorOp)
1360 NARY_SHAPE_INFER(tosa::GreaterEqualOp)
1361 NARY_SHAPE_INFER(tosa::GreaterOp)
1362 NARY_SHAPE_INFER(tosa::IdentityOp)
1363 NARY_SHAPE_INFER(tosa::LogOp)
1364 NARY_SHAPE_INFER(tosa::LogicalAndOp)
1365 NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
1366 NARY_SHAPE_INFER(tosa::LogicalNotOp)
1367 NARY_SHAPE_INFER(tosa::LogicalOrOp)
1368 NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
1369 NARY_SHAPE_INFER(tosa::LogicalXorOp)
1370 NARY_SHAPE_INFER(tosa::MaximumOp)
1371 NARY_SHAPE_INFER(tosa::MinimumOp)
1372 NARY_SHAPE_INFER(tosa::MulOp)
1373 NARY_SHAPE_INFER(tosa::NegateOp)
1374 NARY_SHAPE_INFER(tosa::PowOp)
1375 NARY_SHAPE_INFER(tosa::ReciprocalOp)
1376 NARY_SHAPE_INFER(tosa::RescaleOp)
1377 NARY_SHAPE_INFER(tosa::ReverseOp)
1378 NARY_SHAPE_INFER(tosa::RsqrtOp)
1379 NARY_SHAPE_INFER(tosa::SinOp)
1380 NARY_SHAPE_INFER(tosa::SelectOp)
1381 NARY_SHAPE_INFER(tosa::SubOp)
1382 NARY_SHAPE_INFER(tosa::TanhOp)
1383 NARY_SHAPE_INFER(tosa::ErfOp)
1384 NARY_SHAPE_INFER(tosa::SigmoidOp)
1385 #undef PRED_SHAPE_INFER
1386 
1388  ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
1389  ArrayRef<int64_t> pad,
1390  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1391  llvm::SmallVector<int64_t> outputShape;
1392  outputShape.resize(4, ShapedType::kDynamic);
1393 
1394  // We only know the rank if the input type is unranked.
1395  if (!inputShape) {
1396  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1397  return success();
1398  }
1399 
1400  // Batch and number of channels are identical for pooling layer.
1401  outputShape[0] = inputShape.getDimSize(0);
1402  outputShape[3] = inputShape.getDimSize(3);
1403 
1404  int64_t height = inputShape.getDimSize(1);
1405  int64_t width = inputShape.getDimSize(2);
1406 
1407  if (!ShapedType::isDynamic(height)) {
1408  int64_t padded = height + pad[0] + pad[1] - kernel[0];
1409  outputShape[1] = padded / stride[0] + 1;
1410  }
1411 
1412  if (!ShapedType::isDynamic(width)) {
1413  int64_t padded = width + pad[2] + pad[3] - kernel[1];
1414  outputShape[2] = padded / stride[1] + 1;
1415  }
1416 
1417  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1418  return success();
1419 }
1420 
1421 LogicalResult Conv2DOp::inferReturnTypeComponents(
1422  MLIRContext *context, ::std::optional<Location> location,
1423  Conv2DOp::Adaptor adaptor,
1424  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1425  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1426 
1427  int64_t inputWidth = ShapedType::kDynamic;
1428  int64_t inputHeight = ShapedType::kDynamic;
1429  int64_t weightWidth = ShapedType::kDynamic;
1430  int64_t weightHeight = ShapedType::kDynamic;
1431 
1432  // Input shape describes input width/height and batch.
1433 
1434  ShapeAdaptor inputShape(adaptor.getInput().getType());
1435  if (inputShape.hasRank()) {
1436  outputShape[0] = inputShape.getDimSize(0);
1437  inputHeight = inputShape.getDimSize(1);
1438  inputWidth = inputShape.getDimSize(2);
1439  }
1440 
1441  // Weight shapes describes the filter width/height and the output channels.
1442  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1443  if (weightShape.hasRank()) {
1444  outputShape[3] = weightShape.getDimSize(0);
1445  weightHeight = weightShape.getDimSize(1);
1446  weightWidth = weightShape.getDimSize(2);
1447  }
1448 
1449  // Bias shape can describe the output channels.
1450  ShapeAdaptor biasShape(adaptor.getBias().getType());
1451  if (biasShape.hasRank()) {
1452  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1453  ? biasShape.getDimSize(0)
1454  : outputShape[3];
1455  }
1456 
1457  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1458  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1459  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1460 
1461  if (!ShapedType::isDynamic(inputHeight) &&
1462  !ShapedType::isDynamic(weightHeight)) {
1463  int64_t inputSize = inputHeight + padding[0] + padding[1];
1464  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1465  int64_t unstridedResult = inputSize - filterSize + 1;
1466  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1467  }
1468 
1469  if (!ShapedType::isDynamic(inputWidth) &&
1470  !ShapedType::isDynamic(weightWidth)) {
1471  int64_t inputSize = inputWidth + padding[2] + padding[3];
1472  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1473  int64_t unstridedResult = inputSize - filterSize + 1;
1474  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1475  }
1476 
1477  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1478  return success();
1479 }
1480 
1481 LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
1482 
1483 LogicalResult Conv3DOp::inferReturnTypeComponents(
1484  MLIRContext *context, ::std::optional<Location> location,
1485  Conv3DOp::Adaptor adaptor,
1486  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1487  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);
1488 
1489  int64_t inputWidth = ShapedType::kDynamic;
1490  int64_t inputHeight = ShapedType::kDynamic;
1491  int64_t inputDepth = ShapedType::kDynamic;
1492 
1493  int64_t weightWidth = ShapedType::kDynamic;
1494  int64_t weightHeight = ShapedType::kDynamic;
1495  int64_t weightDepth = ShapedType::kDynamic;
1496 
1497  // Input shape describes input width/height and batch.
1498  ShapeAdaptor inputShape(adaptor.getInput().getType());
1499  if (inputShape.hasRank()) {
1500  outputShape[0] = inputShape.getDimSize(0);
1501  inputDepth = inputShape.getDimSize(1);
1502  inputHeight = inputShape.getDimSize(2);
1503  inputWidth = inputShape.getDimSize(3);
1504  }
1505 
1506  // Weight shapes describes the filter width/height and the output channels.
1507  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1508  if (weightShape.hasRank()) {
1509  outputShape[4] = weightShape.getDimSize(0);
1510  weightDepth = weightShape.getDimSize(1);
1511  weightHeight = weightShape.getDimSize(2);
1512  weightWidth = weightShape.getDimSize(3);
1513  }
1514 
1515  // Bias shape can describe the output channels.
1516  ShapeAdaptor biasShape(adaptor.getBias().getType());
1517  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
1518  outputShape[4] = biasShape.getDimSize(0);
1519  }
1520 
1521  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1522  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1523  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
1524 
1525  if (!ShapedType::isDynamic(inputDepth) &&
1526  !ShapedType::isDynamic(weightDepth)) {
1527  int32_t inputSize = inputDepth + pad[0] + pad[1];
1528  int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1529  int32_t unstridedResult = inputSize - filterSize + 1;
1530  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1531  }
1532 
1533  if (!ShapedType::isDynamic(inputHeight) &&
1534  !ShapedType::isDynamic(weightHeight)) {
1535  int32_t inputSize = inputHeight + pad[2] + pad[3];
1536  int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1537  int32_t unstridedResult = inputSize - filterSize + 1;
1538  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1539  }
1540 
1541  if (!ShapedType::isDynamic(inputWidth) &&
1542  !ShapedType::isDynamic(weightWidth)) {
1543  int32_t inputSize = inputWidth + pad[4] + pad[5];
1544  int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1545  int32_t unstridedResult = inputSize - filterSize + 1;
1546  outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1547  }
1548 
1549  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1550  return success();
1551 }
1552 
1553 LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
1554 
1555 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
1556  MLIRContext *context, ::std::optional<Location> location,
1557  AvgPool2dOp::Adaptor adaptor,
1558  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1559  ShapeAdaptor inputShape(adaptor.getInput().getType());
1560  const Properties &prop = adaptor.getProperties();
1561  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1562  inferredReturnShapes);
1563 }
1564 
1565 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
1566  MLIRContext *context, ::std::optional<Location> location,
1567  MaxPool2dOp::Adaptor adaptor,
1568  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1569  ShapeAdaptor inputShape(adaptor.getInput().getType());
1570  const Properties &prop = adaptor.getProperties();
1571  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
1572  inferredReturnShapes);
1573 }
1574 
1575 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
1576  MLIRContext *context, ::std::optional<Location> location,
1577  DepthwiseConv2DOp::Adaptor adaptor,
1578  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1579  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
1580 
1581  int64_t inputWidth = ShapedType::kDynamic;
1582  int64_t inputHeight = ShapedType::kDynamic;
1583  int64_t inputChannels = ShapedType::kDynamic;
1584 
1585  int64_t weightWidth = ShapedType::kDynamic;
1586  int64_t weightHeight = ShapedType::kDynamic;
1587  int64_t depthChannels = ShapedType::kDynamic;
1588 
1589  // Input shape describes input width/height and batch.
1590  ShapeAdaptor inputShape(adaptor.getInput().getType());
1591  if (inputShape.hasRank()) {
1592  outputShape[0] = inputShape.getDimSize(0);
1593  inputHeight = inputShape.getDimSize(1);
1594  inputWidth = inputShape.getDimSize(2);
1595  inputChannels = inputShape.getDimSize(3);
1596  }
1597 
1598  // Weight shapes describes the filter width/height and the output channels.
1599  ShapeAdaptor weightShape(adaptor.getWeight().getType());
1600  if (weightShape.hasRank()) {
1601  weightHeight = weightShape.getDimSize(0);
1602  weightWidth = weightShape.getDimSize(1);
1603  inputChannels = ShapedType::isDynamic(inputChannels)
1604  ? weightShape.getDimSize(2)
1605  : inputChannels;
1606  depthChannels = weightShape.getDimSize(3);
1607  }
1608 
1609  // If both inputChannels and depthChannels are available we can determine
1610  // the output channels.
1611  if (!ShapedType::isDynamic(inputChannels) &&
1612  !ShapedType::isDynamic(depthChannels)) {
1613  outputShape[3] = inputChannels * depthChannels;
1614  }
1615 
1616  // Bias shape can describe the output channels.
1617  ShapeAdaptor biasShape(adaptor.getBias().getType());
1618  if (biasShape.hasRank()) {
1619  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1620  ? biasShape.getDimSize(0)
1621  : outputShape[3];
1622  }
1623 
1624  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
1625  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
1626  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1627 
1628  if (!ShapedType::isDynamic(inputHeight) &&
1629  !ShapedType::isDynamic(weightHeight)) {
1630  int64_t inputSize = inputHeight + padding[0] + padding[1];
1631  int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1632  int64_t unstridedResult = inputSize - filterSize + 1;
1633  outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1634  }
1635 
1636  if (!ShapedType::isDynamic(inputWidth) &&
1637  !ShapedType::isDynamic(weightWidth)) {
1638  int64_t inputSize = inputWidth + padding[2] + padding[3];
1639  int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1640  int64_t unstridedResult = inputSize - filterSize + 1;
1641  outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1642  }
1643 
1644  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1645  return success();
1646 }
1647 
1649 
1650 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
1651  MLIRContext *context, ::std::optional<Location> location,
1652  TransposeConv2DOp::Adaptor adaptor,
1653  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1654  // outputShape is mutable.
1655  llvm::SmallVector<int64_t> outputShape =
1656  convertToMlirShape(adaptor.getOutShape());
1657 
1658  int64_t inputWidth = ShapedType::kDynamic;
1659  int64_t inputHeight = ShapedType::kDynamic;
1660  int64_t weightWidth = ShapedType::kDynamic;
1661  int64_t weightHeight = ShapedType::kDynamic;
1662 
1663  // Input shape describes input width/height and batch.
1664  ShapeAdaptor inputShape(adaptor.getInput().getType());
1665  if (inputShape.hasRank()) {
1666  outputShape[0] = ShapedType::isDynamic(outputShape[0])
1667  ? inputShape.getDimSize(0)
1668  : outputShape[0];
1669  inputHeight = inputShape.getDimSize(1);
1670  inputWidth = inputShape.getDimSize(2);
1671  }
1672 
1673  // Weight shapes describes the filter width/height and the output channels.
1674  ShapeAdaptor weightShape(adaptor.getFilter().getType());
1675  if (weightShape.hasRank()) {
1676  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1677  ? weightShape.getDimSize(0)
1678  : outputShape[3];
1679  weightHeight = weightShape.getDimSize(1);
1680  weightWidth = weightShape.getDimSize(2);
1681  }
1682 
1683  // Bias shape can describe the output channels.
1684  ShapeAdaptor biasShape(adaptor.getInput().getType());
1685  if (biasShape.hasRank()) {
1686  outputShape[3] = ShapedType::isDynamic(outputShape[3])
1687  ? biasShape.getDimSize(0)
1688  : outputShape[3];
1689  }
1690 
1691  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
1692  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
1693 
1694  if (!ShapedType::isDynamic(inputHeight) &&
1695  !ShapedType::isDynamic(weightHeight)) {
1696  int64_t calculateSize =
1697  (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
1698  outputShape[1] =
1699  ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
1700  }
1701 
1702  if (!ShapedType::isDynamic(inputWidth) &&
1703  !ShapedType::isDynamic(weightWidth)) {
1704  int64_t calculateSize =
1705  (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
1706  outputShape[2] =
1707  ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
1708  }
1709 
1710  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1711  return success();
1712 }
1713 
1714 LogicalResult IfOp::inferReturnTypeComponents(
1715  MLIRContext *context, ::std::optional<Location> location,
1716  IfOp::Adaptor adaptor,
1717  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1719  for (Region *region : adaptor.getRegions()) {
1720  for (auto &block : *region)
1721  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1722  yieldOps.push_back(returnOp);
1723  }
1724 
1725  if (yieldOps.empty())
1726  return failure();
1727 
1728  // Get the initial type information for the yield op.
1729  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1730  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1731  for (auto operand : yieldOps.front().getOperands()) {
1732  resultKnowledge.push_back(
1733  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1734  }
1735 
1736  for (auto yieldOp : yieldOps) {
1737  if (resultKnowledge.size() != yieldOp.getNumOperands())
1738  return failure();
1739 
1740  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1741  int32_t index = it.index();
1742  auto meet = ValueKnowledge::meet(
1743  resultKnowledge[index],
1744  ValueKnowledge::getKnowledgeFromType(it.value().getType()));
1745  if (!meet)
1746  continue;
1747  resultKnowledge[index] = meet;
1748  }
1749  }
1750 
1751  for (const ValueKnowledge &result : resultKnowledge) {
1752  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1753  }
1754 
1755  return success();
1756 }
1757 
1758 LogicalResult WhileOp::inferReturnTypeComponents(
1759  MLIRContext *context, ::std::optional<Location> location,
1760  WhileOp::Adaptor adaptor,
1761  SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1763  for (auto &block : adaptor.getBody())
1764  if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
1765  yieldOps.push_back(returnOp);
1766 
1767  // TOSA's while must have a tosa.yield as its terminator. If not found this
1768  // tosa.while is invalid.
1769  if (yieldOps.empty())
1770  return failure();
1771 
1772  // Get the initial type information from the operand types.
1773  llvm::SmallVector<ValueKnowledge> resultKnowledge;
1774  resultKnowledge.reserve(yieldOps.front().getNumOperands());
1775  for (auto operand : yieldOps.front().getOperands()) {
1776  resultKnowledge.push_back(
1777  ValueKnowledge::getKnowledgeFromType(operand.getType()));
1778  }
1779 
1780  for (auto yieldOp : yieldOps) {
1781  if (resultKnowledge.size() != yieldOp.getNumOperands())
1782  return failure();
1783 
1784  for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
1785  int32_t index = it.index();
1786  if (auto meet = ValueKnowledge::meet(
1787  resultKnowledge[index],
1788  ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
1789  resultKnowledge[index] = meet;
1790  }
1791  }
1792  }
1793 
1794  for (const ValueKnowledge &result : resultKnowledge) {
1795  inferredReturnShapes.push_back(result.getShapedTypeComponents());
1796  }
1797 
1798  return success();
1799 }
1800 
1801 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
1802  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
1803  return llvm::to_vector<4>(vt.getShape());
1804  return std::nullopt;
1805 }
1806 
1807 // parse and print of IfOp refer to the implementation of SCF dialect.
1809  // Create the regions for 'then'.
1810  result.regions.reserve(2);
1811  Region *thenRegion = result.addRegion();
1812  Region *elseRegion = result.addRegion();
1813 
1814  auto &builder = parser.getBuilder();
1816  // Create a i1 tensor type for the boolean condition.
1817  Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
1818  if (parser.parseOperand(cond) ||
1819  parser.resolveOperand(cond, i1Type, result.operands))
1820  return failure();
1821  // Parse optional results type list.
1822  if (parser.parseOptionalArrowTypeList(result.types))
1823  return failure();
1824  // Parse the 'then' region.
1825  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1826  return failure();
1827 
1828  // If we find an 'else' keyword then parse the 'else' region.
1829  if (!parser.parseOptionalKeyword("else")) {
1830  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1831  return failure();
1832  }
1833 
1834  // Parse the optional attribute list.
1835  if (parser.parseOptionalAttrDict(result.attributes))
1836  return failure();
1837  return success();
1838 }
1839 
1840 void IfOp::print(OpAsmPrinter &p) {
1841  bool printBlockTerminators = false;
1842 
1843  p << " " << getCond();
1844  if (!getResults().empty()) {
1845  p << " -> (" << getResultTypes() << ")";
1846  // Print yield explicitly if the op defines values.
1847  printBlockTerminators = true;
1848  }
1849  p << ' ';
1850  p.printRegion(getThenBranch(),
1851  /*printEntryBlockArgs=*/false,
1852  /*printBlockTerminators=*/printBlockTerminators);
1853 
1854  // Print the 'else' regions if it exists and has a block.
1855  auto &elseRegion = getElseBranch();
1856  if (!elseRegion.empty()) {
1857  p << " else ";
1858  p.printRegion(elseRegion,
1859  /*printEntryBlockArgs=*/false,
1860  /*printBlockTerminators=*/printBlockTerminators);
1861  }
1862 
1863  p.printOptionalAttrDict((*this)->getAttrs());
1864 }
1865 
1867  TensorType inputType = getInput().getType();
1868  TensorType outputType = getOutput().getType();
1869  int32_t reverseAxis = getAxis();
1870 
1871  if (reverseAxis < 0)
1872  return emitOpError("expected non-negative reverse axis");
1873  if (inputType.hasRank()) {
1874  int64_t inputRank = inputType.getRank();
1875  // We allow for a special case where the input/output shape has rank 0 and
1876  // axis is also 0.
1877  if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
1878  return emitOpError("expect input tensor rank (")
1879  << inputRank << ") to be larger than reverse axis (" << reverseAxis
1880  << ")";
1881  }
1882  if (outputType.hasRank()) {
1883  int64_t outputRank = outputType.getRank();
1884  if (inputType.hasRank() && outputRank != inputType.getRank())
1885  return emitOpError(
1886  "expect output tensor rank to be equal to input tensor rank");
1887  if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
1888  return emitOpError("expect output tensor rank (")
1889  << outputRank << ") to be larger than reverse axis ("
1890  << reverseAxis << ")";
1891  }
1892  return success();
1893 }
1894 
1895 // parse and print of WhileOp refer to the implementation of SCF dialect.
1899  Region *cond = result.addRegion();
1900  Region *body = result.addRegion();
1901 
1902  OptionalParseResult listResult =
1903  parser.parseOptionalAssignmentList(regionArgs, operands);
1904  if (listResult.has_value() && failed(listResult.value()))
1905  return failure();
1906 
1907  FunctionType functionType;
1908  SMLoc typeLoc = parser.getCurrentLocation();
1909  if (failed(parser.parseColonType(functionType)))
1910  return failure();
1911 
1912  result.addTypes(functionType.getResults());
1913 
1914  if (functionType.getNumInputs() != operands.size()) {
1915  return parser.emitError(typeLoc)
1916  << "expected as many input types as operands "
1917  << "(expected " << operands.size() << " got "
1918  << functionType.getNumInputs() << ")";
1919  }
1920 
1921  // Resolve input operands.
1922  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
1923  parser.getCurrentLocation(),
1924  result.operands)))
1925  return failure();
1926 
1927  // Propagate the types into the region arguments.
1928  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
1929  regionArgs[i].type = functionType.getInput(i);
1930 
1931  return failure(parser.parseRegion(*cond, regionArgs) ||
1932  parser.parseKeyword("do") || parser.parseRegion(*body) ||
1934 }
1935 
1937  Block::BlockArgListType blocksArgs,
1938  ValueRange initializers,
1939  StringRef prefix = "") {
1940  assert(blocksArgs.size() == initializers.size() &&
1941  "expected same length of arguments and initializers");
1942  if (initializers.empty())
1943  return;
1944 
1945  parser << prefix << '(';
1946  llvm::interleaveComma(
1947  llvm::zip(blocksArgs, initializers), parser,
1948  [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
1949  parser << ")";
1950 }
1951 
1952 void WhileOp::print(OpAsmPrinter &parser) {
1953  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
1954  " ");
1955  parser << " : ";
1956  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
1957  parser << ' ';
1958  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
1959  parser << " do ";
1960  parser.printRegion(getBody());
1961  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1962 }
1963 
1964 //===----------------------------------------------------------------------===//
1965 // TOSA Attribute Definitions.
1966 //===----------------------------------------------------------------------===//
1967 
1968 #define GET_ATTRDEF_CLASSES
1969 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
1970 
1971 //===----------------------------------------------------------------------===//
1972 // TOSA Operator Definitions.
1973 //===----------------------------------------------------------------------===//
1974 
1975 #define GET_OP_CLASSES
1976 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition: TosaOps.cpp:426
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1218
#define REDUCE_SHAPE_INFER(OP)
Definition: TosaOps.cpp:1243
static bool hasZeroDimension(ShapedType shapedType)
Definition: TosaOps.cpp:207
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:223
static void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter unary operators that have scale relationship between their...
Definition: TosaOps.cpp:481
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1387
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition: TosaOps.cpp:494
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias)
The tosa.fully_connected op has its own builder as it does not have strides/dilation/padding.
Definition: TosaOps.cpp:407
static LogicalResult verifyReduceOp(T op)
Definition: TosaOps.cpp:1268
#define NARY_SHAPE_INFER(OP)
Definition: TosaOps.cpp:1336
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings, Value padConst)
This builder is called on TOSA pad operator when an explicit pad_const value is passed in.
Definition: TosaOps.cpp:506
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition: TosaOps.cpp:1324
#define COMPATIBLE_RETURN_TYPES(OP)
Definition: TosaOps.cpp:1234
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:522
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition: TosaOps.cpp:386
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:463
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition: TosaOps.cpp:819
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition: TosaOps.cpp:364
static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Definition: TosaOps.cpp:1936
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast_or_null() const
Definition: Attributes.h:184
U dyn_cast() const
Definition: Attributes.h:179
IntegerType getI32Type()
Definition: Builders.cpp:83
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
An attribute that represents a reference to a dense integer vector or tensor object.
This class defines a virtual interface for reading a bytecode stream, providing hooks into the byteco...
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
This class defines a virtual interface for writing to a bytecode stream, providing hooks into the byt...
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This class is used to represent the version of a dialect, for the purpose of polymorphic destruction.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:93
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:120
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:239
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
Definition: QuantUtils.cpp:219
ParseResult parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr)
Definition: TosaOps.cpp:164
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
Definition: QuantUtils.cpp:164
void printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr)
Definition: TosaOps.cpp:186
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:193
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
Definition: Diagnostics.h:491
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Statically known information for a particular Value.
Definition: ShapeUtils.h:33
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition: ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45