#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
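// Note: the .inc / .cpp.inc files included above are generated at build time
// (TableGen); they supply the dialect, enum, interface, availability, and
// bytecode definitions that the hand-written code below builds on.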
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));
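  // The inliner hook above only allows inlining into regions whose parent op
  // is tosa::IfOp or tosa::WhileOp; every other destination is rejected.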
  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
    return ::writeAttribute(attr, writer);

  LogicalResult writeType(Type type,
    return ::writeType(type, writer);

  std::unique_ptr<DialectVersion>
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
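  // Note: TOSA encodes an unknown dimension as -1, whereas MLIR's ShapedType
  // uses the ShapedType::kDynamic sentinel, hence the dim-by-dim remapping in
  // the lambda above. Illustrative only: a TOSA shape {2, -1, 3} maps to the
  // MLIR shape {2, ShapedType::kDynamic, 3}.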
  Type elementType = variableOp.getType();
void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"

#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return builder.create<tosa::ConstShapeOp>(
        loc, type, llvm::cast<DenseIntElementsAttr>(value));

  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
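  // Constant materialization: !tosa.shape values backed by a
  // DenseIntElementsAttr become tosa.const_shape ops, and any other
  // ElementsAttr becomes a plain tosa.const.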
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();

         << "expected shaped type";

         << "expected attribute";

  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,

         << "expected Typed attr";

  initialValueAttr = nullptr;

         << "expected type after colon";

  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
                                       TypeAttr typeAttr,
                                       Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =

  if (initialValueAttr) {
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
                                   Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
                                Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return builder.create<tosa::ConstOp>(loc, padConstType, padConstAttr);
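  // The helper above materializes the pad value as a tosa.const tensor: the
  // attribute is built from a float value when the source element type is a
  // FloatType, otherwise from an integer value (construction elided above).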
template <typename T>
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
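  // Summary of the shared convolution verifier above: quantized element types
  // are unwrapped to their storage types; bias and result must agree when both
  // are float; fp8 (f8E5M2 / f8E4M3FN) inputs and weights must have identical
  // element types; input and weight must be consistently float or integer;
  // each operand's element type must match its zero-point element type; and
  // statically known zero points go through verify{Input,Weight}ZeroPoint.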
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
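  // Accumulator-type rules enforced above: i8 inputs accumulate in i32, i16
  // in i48, f8 in f16, f16 in f16 or f32, bf16 in f32, and f32 in f32.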
template <typename T>
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,

      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                            StringRef name1, Type type2,
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
  tosa::VariableOp varOp = nullptr;

  if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
    if (symName == tosaOp.getName()) {

template <typename T>
  StringRef symName = op.getName();
    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";
template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (!ShapedType::isDynamic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;
  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = cast<IntegerType>(inputETy).isUnsigned();
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if (isUnsigned ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};

  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)

  finalOutputType = outputShapedType.clone(accElementType);
                                     DenseArrayAttr kernel, DenseArrayAttr stride,
                                     DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);
  int64_t input1Zp{0};
  int64_t outputZp{0};

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);
                          StringRef name, Type variableType,
  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {
      outShape[i + rankDiff] = resolvedDim;
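  // Broadcast resolution above follows the usual right-aligned rules: the
  // result rank is the maximum operand rank, a dimension of 1 stretches to
  // match the other operand, and otherwise the sizes must agree. Illustrative
  // only: shapes {2, 1, 3} and {7, 3} resolve to {2, 7, 3}.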
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outShape.push_back(inputShape.getDimSize(i));
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;
                                     const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                        outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                             inputImagType.getDimSize(1));
  if (!ShapedType::isDynamic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                            inputImagType.getDimSize(2));
  if (!ShapedType::isDynamic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))

      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
            "Cannot concat tensors with different sizes"
            " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
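  // Concat shape inference: non-axis dimensions must agree across ranked
  // operands, and the axis dimension is the sum of the operand axis sizes
  // (or dynamic if any contributing operand is dynamic along the axis).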
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))

        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  int64_t axisSum = 0;
  for (const auto &input : inputList) {
    if (inputShape.isDynamicDim(axis)) {

    axisSum += inputShape.getDimSize(axis);

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);
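  // MatMul shape inference above: the 3-D result takes dimensions 0 and 1
  // from operand a and dimension 2 from operand b, falling back to dynamic
  // sizes when an operand is unranked.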
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
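  // Pad shape inference: each output dimension is
  // input_dim + pad_front + pad_back; a dynamic input dimension or a negative
  // (not-yet-known) padding entry yields a dynamic output dimension.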
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {

        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];
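  // Slice shape inference: a size entry of -1 means "to the end of the
  // dimension", so the output size becomes input_dim - start; otherwise the
  // requested size is used as long as start + size fits in the input
  // dimension.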
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
          "expect input1 and output to have the same ranks, got ")
          << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
            "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";
  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (!ShapedType::isDynamic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());
  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (!ShapedType::isDynamic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(
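  // Reshape inference resolves a dynamic (-1) target dimension by dividing
  // the input's element count by the product of the static target dimensions.
  // Illustrative only: reshaping a static 2x3x4 input (24 elements) to
  // {4, -1} yields {4, 6}.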
  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();
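  // getZeroPoint: a float zero point is only accepted when it is exactly
  // zero; an integer zero point is read sign-extended or zero-extended
  // depending on the SIGN_EXTEND flag threaded through the ZERO_POINT_HELPER
  // macro below.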
template <typename T>
                                 const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                 const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
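  // Transpose shape inference: output dimension i takes the size of input
  // dimension perms[i]; out-of-range permutation entries and rank-0 inputs
  // are handled by the early exits above.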
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                             static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          builder.create<tensor::DimOp>(getLoc(), input, dimInInput)

          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);

  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;
    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /

  if (outputHeight < 0 || outputWidth < 0) {
        "calculated output height and width must be non-negative, "
        outputHeight, ", width = ", outputWidth);

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
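  // Resize output-size relation used above and re-checked in the verifier:
  //   out = ((in - 1) * scale_n - offset + border) / scale_d + 1
  // per spatial axis. Illustrative arithmetic only: in = 8, scale = 2/1,
  // offset = 0, border = 1 gives ((8 - 1) * 2 - 0 + 1) / 1 + 1 = 16.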
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);

  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  int64_t N = ShapedType::kDynamic;
  int64_t K = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesInShape.hasRank()) {
    N = valuesInShape.getDimSize(0);
    K = valuesInShape.getDimSize(1);
    C = valuesInShape.getDimSize(2);

  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;

  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (N == ShapedType::kDynamic)
    else if (inputN != ShapedType::kDynamic && N != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << N
                           << ", got " << inputN;
    if (W == ShapedType::kDynamic)
    else if (inputW != ShapedType::kDynamic && W != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << W
                           << ", got " << inputW;

    if (C == ShapedType::kDynamic)
    else if (inputC != ShapedType::kDynamic && C != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << C
                           << ", got " << inputC;

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << N << ", got " << outputN;
    if (K == ShapedType::kDynamic)
    else if (outputK != ShapedType::kDynamic && K != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << K << ", got " << outputK;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << C << ", got " << outputC;

  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
    return emitOpError() << "requires dimensions K >= W, got K=" << K
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
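// For example, reducing a tensor<2x3x4xf32> along axis = 1 infers the result
// shape 2x1x4: the reduced dimension collapses to size 1 and the element type
// is carried through unchanged.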
#define COMPATIBLE_RETURN_TYPES(OP)                                           \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                \
    if (l.size() != r.size() || l.size() != 1)                                \
      return false;                                                           \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))             \
      return false;                                                           \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                      \
  }

#define REDUCE_SHAPE_INFER(OP)                                                \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                    \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    Type inputType =                                                          \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                    \
    const Properties &prop = adaptor.getProperties();                         \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,           \
                                  inferredReturnShapes);                      \
  }                                                                           \
  COMPATIBLE_RETURN_TYPES(OP)

#undef REDUCE_SHAPE_INFER
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0.
    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }

  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // The reduced dimension size can only be checked when the output is not
    // the special rank-0 case.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }

  return success();
}
#define NARY_SHAPE_INFER(OP)                                                  \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      OpaqueProperties properties, RegionRange regions,                       \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return NAryInferReturnTypes(operands, inferredReturnShapes);              \
  }

#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();

  // Verify same shape for input1 and output.
  if (failed(verifyCompatibleShape(input1Type, outputType)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(input1Type);
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }

  const Type outputEType = getStorageElementTypeOrSelf(outputType);
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
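// Pooling output extent below follows
//   out = (in + pad_front + pad_back - kernel) / stride + 1
// using floor division. Illustrative example (values assumed): height 32 with
// pad 1+1, kernel 3 and stride 2 gives (32 + 2 - 3) / 2 + 1 = 16.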
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
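  // Spatial output extents follow
  //   out = ((in + pad_before + pad_after) - ((kernel - 1) * dilation + 1)) / stride + 1
  // using floor division. Illustrative example (assumed values): in = 16 with
  // pad 1+1, kernel 3, dilation 1 and stride 2 gives (18 - 3) / 2 + 1 = 8.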
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input depth/height/width and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}
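// For depthwise_conv2d the weight is laid out as [KH, KW, C, M]; the inferred
// output channel count below is C * M (input channels times the depth
// multiplier).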
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the channel multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // The output channels are known when both the input channels and the depth
  // multiplier are static.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
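  // Transposed convolution grows the spatial extent:
  //   out = (in - 1) * stride + out_pad_before + out_pad_after + kernel
  // Illustrative example (assumed values): in = 8, stride 2, out_pad 0+0 and
  // kernel 3 gives (8 - 1) * 2 + 0 + 0 + 3 = 17.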
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult TransposeConv2DOp::verify() {
  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };
  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (!ShapedType::isDynamic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (!ShapedType::isDynamic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // The remaining checks need a ranked output type.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (!ShapedType::isDynamic(inputHeight) &&
        !ShapedType::isDynamic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (!ShapedType::isDynamic(inputWidth) &&
        !ShapedType::isDynamic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  // Skip the remaining check when the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // The multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // The multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }
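  // With per_channel = true the multiplier and shift inputs carry one entry
  // per channel of the input's last dimension; with per_channel = false a
  // single entry is broadcast, so the expected shape checked below is { 1 }.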
  if (!inputType.hasRank())
    return success();

  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the yield operands.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge)
    inferredReturnShapes.push_back(result.getShapedTypeComponents());

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
void IfOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << " " << getCondition();
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    // Print yield explicitly if the op defines values.
    printBlockTerminators = true;
  }
  p.printRegion(getThenGraph(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);

  // Print the else region if it has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/printBlockTerminators);
  }
}
3721 "'then_graph' arguments", getInputList(),
3727 "'else_graph' arguments", getInputList(),
3732 auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3734 "'then_graph' results", getOutputList(),
3739 auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3741 "'else_graph' results", getOutputList(),
3746 auto condType = getCondition().getType();
3748 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'").failed())
    return failure();
  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
  if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                 "'body_graph' results", getInputList(),
                                 "'input_list'").failed())
    return failure();

  // The condition graph must yield exactly one boolean value of size 1.
  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
LogicalResult ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0.
    if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis (" << reverseAxis
             << ")";
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }
  return success();
}
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }
  StringRef symName = getName();
  FailureOr<tosa::VariableOp> varOp = findVariableDecl(*this, symName);
  if (succeeded(varOp))
    return emitOpError("illegal to have multiple declaration of '")
           << symName << "'";
  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();
  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed input types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}
                          getInputList(), " ");
                             getResults().getTypes());
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
                                           Type srcElemType, int64_t zp) {
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
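// Illustrative use of the helper above (values assumed):
// createZeroPointTensor(builder, loc, builder.getIntegerType(8), -128)
// materializes a rank-1 tosa.const holding the zero point; std::nullopt is
// returned for unsupported element types.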
bool isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;

LogicalResult verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp)
        return op->emitOpError(
            "shape operand is not compile time resolvable");
    }
  }
  return success();
}
LogicalResult verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  // Delegate that returns the rank of a tosa shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };
  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();
  auto rank = getRank(*op->getOperandTypes().begin());
  // All operands and results must share that rank.
  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}
LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the values attribute is rank 1.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // Check that the number of elements in the values attribute equals the rank
  // of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (!(count == rank || (count == 1 && rank == 0))) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"