//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return getBody(); }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr, unlike the standard
  // dialect constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType =
      op.getInput().getType().template dyn_cast<RankedTensorType>();
  auto weightType =
      op.getWeight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types.
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must carry a constructed quantization attribute, and an
  // unquantized (float) type must not have one.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }

  return success();
}
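
// Illustrative summary of the checks above (hypothetical operand types): an
// f32 input with an f32 weight and no quantization_info is accepted; a
// quantized i8 input with a quantized i8 weight and quantization_info is
// accepted; an f32 input paired with a quantized weight is rejected, as is a
// quantized input/weight pair that lacks quantization_info.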

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputETy = getInput().getType().cast<ShapedType>().getElementType();
  auto resultETy = getType().cast<ShapedType>().getElementType();

  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    inputETy = quantType.getStorageType();

  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    resultETy = quantType.getStorageType();

  if (inputETy.isF32() && resultETy.isF32())
    return success();
  if (inputETy.isInteger(8) && resultETy.isInteger(8))
    return success();
  if (inputETy.isInteger(16) && resultETy.isInteger(16))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bitwidths of the input and weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildTransConvOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          Value weight, Value bias,
                                          ArrayAttr outpad, ArrayAttr stride,
                                          ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also intended to be generated where a
/// fully_connected op must be constructed but the weight is not a constant.
/// In that case, the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<ShapedType>();
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<ShapedType>();
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}
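
// A worked illustration of the accumulator rule above (hypothetical types):
// quantized operands with an 8-bit storage type produce an i32 accumulator
// result type, while a 16-bit storage type widens the accumulator to i48.
// Float operands carry no quantization attribute and keep the
// caller-supplied outputType unchanged.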

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool2d operator has its own builder
/// as it has additional parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator that needs to create its
/// own OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<FloatAttr>().getValueAsDouble());
  }
}

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for
      // invalid operands and for ranked tensors.
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
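
// A worked illustration of the broadcast rules above (hypothetical shapes):
// [2, 1] against [1, 3] resolves to [2, 3]; a lower-rank shape such as [3]
// against [2, 3] is right-aligned via rankDiff and also resolves to [2, 3];
// [2, 2] against [3, 3] fails because neither mismatched dimension is 1.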

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
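
// For example (hypothetical shapes): an argmax with axis = 1 over a 4x5x6
// input drops that dimension and infers a 4x6 result shape.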

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamicSize)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamicSize;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
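
// For example (hypothetical shapes): concatenating 2x3 and 2x5 inputs along
// axis = 1 infers a 2x8 result; if any input's extent along the axis is
// dynamic, the result's extent along that axis becomes dynamic as well.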

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
  return success();
}

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return succeeded(verifyCompatibleShape(l[0], r[0]));
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor weightShape = operands.getShape(1);
  ShapeAdaptor biasShape = operands.getShape(2);

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamicSize);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamicSize
                      ? biasShape.getDimSize(0)
                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape = operands.getShape(0);
  ShapeAdaptor rhsShape = operands.getShape(1);

  // All shapes are dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamicSize);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamicSize
                      ? rhsShape.getDimSize(0)
                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
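
// For example (hypothetical shapes): a 2x3x4 lhs against a 2x4x5 rhs infers
// a 2x3x5 result; a dynamic lhs batch dimension is recovered from the rhs
// when the rhs is ranked.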

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor paddingShape = operands.getShape(1);
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank using the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamicSize);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
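
// For example (hypothetical values): a 3x4 input with constant paddings
// [[1, 1], [2, 2]] infers an output of (3 + 1 + 1) x (4 + 2 + 2) = 5x8.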

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).getSize();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TileOpAdaptor adaptor(operands, attributes);
  ArrayAttr multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape = operands.getShape(0);
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied to a known size.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamicSize)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
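
// For example (hypothetical values): tiling a 2x3 input with multiples
// [3, 2] infers a 6x6 result; dynamic input dimensions stay dynamic.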

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ReshapeOpAdaptor adaptor(operands, attributes);
  ShapeAdaptor inputShape = operands.getShape(0);

  ArrayAttr newShape = adaptor.getNewShape();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer from the total number of elements, so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the number of elements covered by the static dimensions. This
  // allows us to infer the length of the remaining dynamic dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != ShapedType::kDynamicSize) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == ShapedType::kDynamicSize)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}
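
// For example (hypothetical values): reshaping a static 2x3 input (6
// elements) with new_shape [-1, 2] gives staticMul = 2, so the dynamic
// dimension resolves to 6 / 2 = 3 and the inferred shape is [3, 2].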

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor permsShape = operands.getShape(1);

  // If the input rank or permutation length is unknown, the output rank is
  // unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of
  // the input, which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  // Without the input dims we cannot determine the output dim sizes, but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same, we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
  // If the permutations are a constant, we can directly determine the output
  // shape.
  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
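
// For example (hypothetical values): with a 1x2x3 input and constant perms
// [2, 0, 1], outputShape[i] = inputShape[perms[i]], giving [3, 1, 2].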

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesShape = operands.getShape(0);
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamicSize)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ResizeOpAdaptor adaptor(operands, attributes);
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamicSize);

  int32_t inHeight = ShapedType::kDynamicSize;
  int32_t inWidth = ShapedType::kDynamicSize;

  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    outputShape[3] = inputShape.getDimSize(3);

    inHeight = inputShape.getDimSize(1);
    inWidth = inputShape.getDimSize(2);
  }

  int32_t shift = adaptor.getShift();
  llvm::SmallVector<int64_t> newShape;
  getI64Values(adaptor.getOutputSize(), newShape);
  outputShape[1] = newShape[0];
  outputShape[2] = newShape[1];

  llvm::SmallVector<int64_t> strideInt;
  llvm::SmallVector<int64_t> offsetInt;
  llvm::SmallVector<double> strideFp;
  llvm::SmallVector<double> offsetFp;
  getI64Values(adaptor.getOffset(), offsetInt);
  getF64Values(adaptor.getOffsetFp(), offsetFp);
  getI64Values(adaptor.getStride(), strideInt);
  getF64Values(adaptor.getStrideFp(), strideFp);

  // If the integer stride is zero, we know that the resize indexing needs to
  // be performed in floating point. Use the floating-point variant to compute
  // the resize shape.
  bool fpMode = strideInt[0] == 0;

  // If the attribute leaves dimensions unknown, we can compute the output
  // shape based on the offset and stride. If we line up perfectly with the
  // last index, we need to round up the size to include it.
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
    float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[1] = std::ceil(sizeFp) + round;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
    float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[2] = std::ceil(sizeFp) + round;
  }

  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
    int64_t size = (inHeight - 1);
    size = ((size << shift) - offsetInt[0]) / strideInt[0];
    outputShape[1] = size + 1;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
    int64_t size = (inWidth - 1);
    size = ((size << shift) - offsetInt[1]) / strideInt[1];
    outputShape[2] = size + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
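
// A worked integer-mode illustration (hypothetical values): with
// inHeight = 10, shift = 8, offsetInt[0] = 0, and strideInt[0] = 256 (a 1:1
// fixed-point stride), size = ((10 - 1) << 8) / 256 = 9, so the inferred
// output height is 9 + 1 = 10.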

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesInShape = operands.getShape(0);
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape = operands.getShape(2);
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamicSize)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (!operandShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  int64_t axisVal = axis.getValue().getSExtValue();
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return ReduceInferReturnTypes(operands.getShape(0),                        \
                                  attributes.get("axis").cast<IntegerAttr>(),  \
                                  inferredReturnShapes);                       \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    const ValueShapeRange &operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input is unranked, the only thing we know about the output is its
  // rank.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for the pooling layer.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int32_t height = inputShape.getDimSize(1);
  int32_t width = inputShape.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
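
// A worked illustration of the arithmetic above (hypothetical values):
// height = 8, pad = {1, 1, ...}, kernel = {3, ...}, stride = {2, ...} gives
// padded = 8 + 1 + 1 - 3 = 7 and an inferred output height of 7 / 2 + 1 = 4.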

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
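
// A worked illustration of the arithmetic above (hypothetical values):
// inputHeight = 8, padding = {1, 1, ...}, weightHeight = 3, dilation = 1,
// stride = 2 gives inputSize = 10, filterSize = 3, unstridedResult = 8, and
// an inferred output height of (8 - 1) / 2 + 1 = 4.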

LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputDepth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height/depth and the output
  // channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
    weightDepth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[4] =
        (outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available, we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(adaptor.getOutShape(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getOutPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t calculateSize =
        (inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
    outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t calculateSize =
        (inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
    outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
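
// A worked illustration of the arithmetic above (hypothetical values):
// inputHeight = 4, stride = {2, ...}, zero padding, and weightHeight = 3
// give (4 - 1) * 2 - 0 - 0 + 3 = 9 as the inferred output height when
// out_shape leaves that dimension dynamic.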

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : *regions[1])
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If one is not
  // found, this tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
Returns the shape of index&#39;th operand.
bool hasStaticShape() const
Returns whether the shape is fully static.
static LogicalResult verifyConvOp(T op)
Definition: TosaOps.cpp:103
This represents an operation in an abstracted form, suitable for use with the builder APIs...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:258
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition: TosaOps.cpp:343
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
Definition: QuantUtils.cpp:116
bool isDynamicDim(int index) const
Returns whether the index&#39;th dimension is dynamic.
int64_t ceil(Fraction f)
Definition: Fraction.h:65
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings, Value padConst)
This builder is called on TOSA pad operator when an explicit pad_const value is passed in...
Definition: TosaOps.cpp:315
Type getType() const
Return the type of this value.
Definition: Value.h:118
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, ArrayAttr kernel, ArrayAttr stride, ArrayAttr pad)
Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr but avg_pool operator has...
Definition: TosaOps.cpp:272
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:328
int64_t getDimSize(int index) const
Returns the size of the index&#39;th dimension.
static ValueKnowledge getKnowledgeFromType(Type type)
Definition: ShapeUtils.h:45
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
static void getF64Values(ArrayAttr arrayAttr, SmallVector< double > &values)
Definition: TosaOps.cpp:337
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Definition: QuantUtils.cpp:189
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
Definition: QuantUtils.cpp:235
ValueRange getValues() const
Returns the Values in the ValueRange.
This class helps build Operations.
Definition: Builders.h:192
IntegerType getI32Type()
Definition: Builders.cpp:54
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias)
The tosa.fully_connected op has its own builder as it does not have strides/dilation/padding.
Definition: TosaOps.cpp:217
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type. ...
Definition: FoldUtils.cpp:50
An attribute that represents a reference to a dense integer vector or tensor object.
SmallVector< Type, 4 > types
Types of the results of this operation.
int64_t getRank() const
Returns the rank of the shape.