#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
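// The *.cpp.inc / *.inc files included above are generated from the TOSA
// TableGen/ODS definitions (dialect, enums, availability, interfaces, and
// bytecode support); they are not hand-written sources.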
  // Inlining into a region is only legal when the destination region belongs
  // to a tosa.cond_if or tosa.while_loop.
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const override {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));

  Type elementType = variableOp.getType();
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
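// Constant materialization hook: fold results of tosa shape type backed by a
// DenseIntElementsAttr become tosa.const_shape ops; any other ElementsAttr
// becomes a tosa.const.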
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }

  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
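// Custom assembly helpers for the variable op: derive the variable's shape
// and element-type attributes from a parsed shaped (ranked) type, optionally
// taking the type from a typed initial-value attribute after '='.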
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

    return parser.emitError(parser.getCurrentLocation())
           << "expected attribute";

  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                  typeAttr);
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected Typed attr";

  initialValueAttr = nullptr;

    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";

  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
                                      TypeAttr typeAttr,
                                      Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =

  if (initialValueAttr) {
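// Attribute parsing with enum support: rounding_mode, resize mode, and
// nan_mode values may be written as bare keywords and are symbolized into
// their corresponding enum attributes.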
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
  llvm::StringRef name;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
      auto sym = symbolizeRoundingMode(kw);
        << "invalid rounding_mode value: " << kw;
  }

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
        << "invalid resize mode value: " << kw;
  }

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
        << "invalid nan_mode value: " << kw;
  }
template <typename EnumType>
static ParseResult parseWithEnumHandling(OpAsmParser &parser,
                                         OperationState &result) {

          [&]() { return parser.parseOperand(operands.emplace_back()); }))
    return failure();

  if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
    return failure();
  parser << namedAttr.getName().strref() << " = ";

  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();
  }
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;

  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());
    }
  }

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);
    });
  }

  llvm::interleaveComma(op->getAttrs(), parser,
                        [&](const NamedAttribute namedAttr) {
                          printNamedAttr(parser, namedAttr);
                        });
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
                                  Value valZp, StringRef name) {
  const bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }

                            Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
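// Shared convolution verifier: the input/weight/bias/result element types
// must be consistent (float-ness must match, fp8 operands must agree, and a
// float bias must match the result type), and the zero-point operands must
// match their tensors and pass the zero-point checks.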
template <typename T>
static LogicalResult verifyConvOp(T op) {
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    return op.emitOpError(
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      return op.emitOpError(
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
    return op.emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;
  }

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }
template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
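// Conv ERROR_IF checks: padding values must be >= 0, stride and dilation
// values >= 1, static output sizes must agree with the convolution output
// formula per dimension, and the bias channel count must equal the output
// channel count or be 1.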
template <typename T>

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)
        return success();

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
          stride);
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
               << stride;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

      return success();
    };

    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                         StringRef name1, Type type2,
                                         StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return success();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
  tosa::VariableOp varOp = nullptr;

    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
      if (symName == tosaOp.getName()) {

template <typename T>

  StringRef symName = op.getName();

    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
  if (!inputType) {
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    return failure();
  }
  if (!outputType) {
    op.emitOpError("expect shaped tensor for output, got ") << outType;
    return failure();
  }
  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
    return failure();
  }
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>

  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();
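// tosa.clamp: min_val/max_val must be attributes of the input element type,
// min_val <= max_val (using unsigned or boolean comparison where
// appropriate), and float bounds must not be NaN.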
  Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }

  Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};

  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)

  finalOutputType = outputShapedType.clone(accElementType);

                                  DenseArrayAttr kernel, DenseArrayAttr stride,
                                  DenseArrayAttr pad, TypeAttr accType) {
  int64_t inputZp{0};
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
  }

  result.types.push_back(outputType);

  int64_t input1Zp{0};
  int64_t outputZp{0};

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
  }

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                           StringRef name, Type variableType,

  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");
  }

  auto elementType = shapedType.getElementType();
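// Broadcast resolution for elementwise ops: the output rank is the maximum
// rank over all ranked operands, and each output dimension is resolved across
// operands, with size-1 dimensions broadcasting against the other operand.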
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())
    return failure();

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                          const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
         << dimName << " to be a power of two, got " << dimSize;
  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                    outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(

  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;
  };

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())
      continue;

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, input.getType(), outType));
      }))
    return failure();

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }
  }

  int64_t axisSum = 0;
  for (const auto &input : inputList) {
    if (inputShape.isDynamicDim(axis)) {
    }
    axisSum += inputShape.getDimSize(axis);
  }

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;
  }

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto paddingRank =
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
  }
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];
        }
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  }
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      return emitOpError() << "require shift to be 0 for float type";
  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
                         "shapes, got ")
             << aType << " and " << bType;
  }

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())
    return success();

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  const TensorType inputType = getInput1().getType();
  const TensorType outputType = getOutput().getType();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
    }
  }
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto rank =
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
    return failure();

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
    } else {
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
    }
  }

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
    return emitOpError(
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {
      staticMul *= val;
    }
  }

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "
                           << newShapeDim;
  }

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;
      }
    }

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  return mlir::success();
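// Zero-point plumbing: getZeroPoint extracts the scalar zero point from a zp
// operand (a float zero point must be zero), verifyZeroPoint enforces the
// per-type rules (non-int8 integer types require a zero zp; unsigned int16
// rescale additionally allows 32768), and the ZERO_POINT_HELPER macro stamps
// out the per-op get*/verify* accessor pairs.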
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
    if (signExtend)
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    else
      return zpAttr.getValues<APInt>()[0].getZExtValue();
  }

template <typename T>
static LogicalResult verifyZeroPoint(T op, Value zpVal, const int64_t &zp,
                                     const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";
  }

                                     const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;
  }
  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
  }

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))
    return failure();

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
  }

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::to_vector(llvm::map_range(
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;

    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        W != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
  }
2744 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2745 MLIRContext *context, ::std::optional<Location> location,
2746 ResizeOp::Adaptor adaptor,
2749 outputShape.resize(4, ShapedType::kDynamic);
2752 if (!inputShape.hasRank())
2755 outputShape[0] = inputShape.getDimSize(0);
2756 outputShape[3] = inputShape.getDimSize(3);
2757 int64_t inputHeight = inputShape.getDimSize(1);
2758 int64_t inputWidth = inputShape.getDimSize(2);
2760 if ((inputHeight == ShapedType::kDynamic) ||
2761 (inputWidth == ShapedType::kDynamic))
2775 const int64_t outputHeight =
2776 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2780 const int64_t outputWidth =
2781 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2785 if (outputHeight < 0 || outputWidth < 0) {
2788 "calculated output height and width must be non-negative, "
2790 outputHeight,
", width = ", outputWidth);
2793 outputShape[1] = outputHeight;
2794 outputShape[2] = outputWidth;
2800 const Value input = getInput();
2801 const Value output = getOutput();
2802 const RankedTensorType inputType =
2803 llvm::dyn_cast<RankedTensorType>(input.
getType());
2804 const RankedTensorType outputType =
2805 llvm::dyn_cast<RankedTensorType>(output.
getType());
2817 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
2818 return emitOpError(
"expect all scale values to be > 0, got ")
2821 const int64_t scaleYN = scaleValues[0];
2822 const int64_t scaleYD = scaleValues[1];
2823 const int64_t scaleXN = scaleValues[2];
2824 const int64_t scaleXD = scaleValues[3];
2826 const int64_t offsetY = offsetValues[0];
2827 const int64_t offsetX = offsetValues[1];
2829 const int64_t borderY = borderValues[0];
2830 const int64_t borderX = borderValues[1];
2837 const int64_t oh = outputType.getDimSize(1);
2838 const int64_t ow = outputType.getDimSize(2);
2839 const int64_t ih = inputType.getDimSize(1);
2840 const int64_t iw = inputType.getDimSize(2);
2846 if (ih != ShapedType::kDynamic && ih != 1) {
2847 const std::optional<int64_t> calculatedOutHeightMinusOne =
2848 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2849 if (!calculatedOutHeightMinusOne.has_value())
2850 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2851 "border_y ")
2852 << "to be wholly divisible by scale_y_d, got ((" << ih
2853 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2854 << ") / " << scaleYD;
2855 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2856 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2857 return emitOpError("calculated output height did not match expected: ")
2858 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2865 if (iw != ShapedType::kDynamic && iw != 1) {
2866 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2867 const std::optional<int64_t> calculatedOutWidthMinusOne =
2868 idivCheck(scaledInWidth, scaleXD);
2869 if (!calculatedOutWidthMinusOne.has_value())
2870 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2871 "border_x ")
2872 << "to be wholly divisible by scale_x_d, got ((" << iw
2873 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2874 << ") / " << scaleXD;
2875 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2876 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2877 return emitOpError("calculated output width did not match expected: ")
2878 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2884 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2885 MLIRContext *context, ::std::optional<Location> location,
2886 ScatterOp::Adaptor adaptor,
2889 outputShape.resize(3, ShapedType::kDynamic);
2891 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2892 if (valuesInShape.hasRank()) {
2893 outputShape[0] = valuesInShape.getDimSize(0);
2894 outputShape[1] = valuesInShape.getDimSize(1);
2895 outputShape[2] = valuesInShape.getDimSize(2);
2898 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2899 if (indicesShape.hasRank()) {
2900 if (outputShape[0] == ShapedType::kDynamic)
2901 outputShape[0] = indicesShape.getDimSize(0);
2904 ShapeAdaptor inputShape(adaptor.getInput().getType());
2905 if (inputShape.hasRank()) {
2906 if (outputShape[0] == ShapedType::kDynamic)
2907 outputShape[0] = inputShape.getDimSize(0);
2908 if (outputShape[2] == ShapedType::kDynamic)
2909 outputShape[2] = inputShape.getDimSize(2);
2931 int64_t N = ShapedType::kDynamic;
2932 int64_t K = ShapedType::kDynamic;
2933 int64_t W = ShapedType::kDynamic;
2934 int64_t C = ShapedType::kDynamic;
2935 if (valuesInShape.hasRank()) {
2936 N = valuesInShape.getDimSize(0);
2937 K = valuesInShape.getDimSize(1);
2938 C = valuesInShape.getDimSize(2);
2940 if (indicesShape.hasRank()) {
2941 const int64_t indicesN = indicesShape.getDimSize(0);
2942 W = indicesShape.getDimSize(1);
2943 if (N == ShapedType::kDynamic)
2944 N = indicesN;
2945 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2946 return emitOpError() << "requires indices dimension 0 to have size " << N
2947 << ", got " << indicesN;
2949 if (inputShape.hasRank()) {
2950 const int64_t inputN = inputShape.getDimSize(0);
2951 const int64_t inputW = inputShape.getDimSize(1);
2952 const int64_t inputC = inputShape.getDimSize(2);
2953 if (N == ShapedType::kDynamic)
2954 N = inputN;
2955 else if (inputN != ShapedType::kDynamic && N != inputN)
2956 return emitOpError() << "requires input dimension 0 to have size " << N
2957 << ", got " << inputN;
2958 if (W == ShapedType::kDynamic)
2959 W = inputW;
2960 else if (inputW != ShapedType::kDynamic && W != inputW)
2961 return emitOpError() << "requires input dimension 1 to have size " << W
2962 << ", got " << inputW;
2964 if (C == ShapedType::kDynamic)
2965 C = inputC;
2966 else if (inputC != ShapedType::kDynamic && C != inputC)
2967 return emitOpError() << "requires input dimension 2 to have size " << C
2968 << ", got " << inputC;
2970 if (outputShape.hasRank()) {
2971 const int64_t outputN = outputShape.getDimSize(0);
2972 const int64_t outputK = outputShape.getDimSize(1);
2973 const int64_t outputC = outputShape.getDimSize(2);
2974 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2975 N != outputN)
2976 return emitOpError() << "requires values_out dimension 0 to have size "
2977 << N << ", got " << outputN;
2978 if (K == ShapedType::kDynamic)
2979 K = outputK;
2980 else if (outputK != ShapedType::kDynamic && K != outputK)
2981 return emitOpError() << "requires values_out dimension 1 to have size "
2982 << K << ", got " << outputK;
2983 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2984 C != outputC)
2985 return emitOpError() << "requires values_out dimension 2 to have size "
2986 << C << ", got " << outputC;
2988 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2989 return emitOpError() << "requires dimensions K >= W, got K=" << K
2990 << " and W=" << W;
2998 int64_t axisVal = axis.getValue().getSExtValue();
2999 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3005 operandShape.getDims(outputShape);
3006 outputShape[axisVal] = 1;
3011 #define COMPATIBLE_RETURN_TYPES(OP) \
3012 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3013 if (l.size() != r.size() || l.size() != 1) \
3014 return false; \
3015 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3016 return false; \
3017 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3018 }
3020 #define REDUCE_SHAPE_INFER(OP) \
3021 LogicalResult OP::inferReturnTypeComponents( \
3022 MLIRContext *context, ::std::optional<Location> location, \
3023 OP::Adaptor adaptor, \
3024 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3025 Type inputType = \
3026 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3027 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3028 const Properties &prop = adaptor.getProperties(); \
3029 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3030 inferredReturnShapes); \
3032 COMPATIBLE_RETURN_TYPES(OP)
3040 #undef REDUCE_SHAPE_INFER
3042 #undef COMPATIBLE_RETURN_TYPES
3044 template <typename T>
3045 static LogicalResult verifyReduceOp(T op) {
3047 TensorType inputType = op.getInput().getType();
3048 TensorType outputType = op.getOutput().getType();
3049 int32_t reduceAxis = op.getAxis();
3051 if (reduceAxis < 0) {
3052 op.emitOpError("reduce axis must not be negative");
3056 int64_t inputRank = inputType.getRank();
3059 if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
3060 op.emitOpError("expect input tensor rank (")
3061 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3067 int64_t outputRank = outputType.getRank();
3068 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3069 op.emitOpError(
3070 "expect output tensor rank to be equal to input tensor rank");
3073 if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
3074 op.emitOpError("expect output tensor rank (")
3075 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3081 if (outputRank != 0) {
3082 auto outputShape = outputType.getShape();
3083 if (!outputType.isDynamicDim(reduceAxis) &&
3084 outputShape[reduceAxis] != 1) {
3085 op.emitOpError("expect reduced dimension size to be 1, got ")
3086 << outputShape[reduceAxis];
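// Illustrative example (not in the original source): for a reduce operator on
// a tensor<2x3x4xf32> with axis = 1, the checks above require the result to
// be tensor<2x1x4xf32>; an axis of 3 or a reduced dimension other than 1 is
// rejected.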
3113 #define NARY_SHAPE_INFER(OP) \
3114 LogicalResult OP::inferReturnTypeComponents( \
3115 MLIRContext *context, ::std::optional<Location> location, \
3116 ValueShapeRange operands, DictionaryAttr attributes, \
3117 OpaqueProperties properties, RegionRange regions, \
3118 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3119 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3159 #undef PRED_SHAPE_INFER
3161 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3162 MLIRContext *context, ::std::optional<Location> location,
3163 NegateOp::Adaptor adaptor,
3165 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3172 const Type input1Type = getInput1().getType();
3173 const Type outputType = getOutput().getType();
3180 return emitOpError() << "requires the same shape for input1 and output";
3183 const Type input1ZpEType =
3185 if (input1EType != input1ZpEType) {
3186 return emitOpError("expect both input1 and its zero point are the same "
3187 "element type, got ")
3188 << input1EType << " and " << input1ZpEType;
3191 const Type outputZpEType =
3193 if (outputEType != outputZpEType) {
3194 return emitOpError("expect both output and its zero point are the same "
3195 "element type, got ")
3196 << outputEType << " and " << outputZpEType;
3199 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3200 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3203 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3204 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3215 outputShape.resize(4, ShapedType::kDynamic);
3230 if (ShapedType::isStatic(height)) {
3231 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3232 outputShape[1] = padded / stride[0] + 1;
3235 if (ShapedType::isStatic(width)) {
3236 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3237 outputShape[2] = padded / stride[1] + 1;
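// Worked example (illustrative, not in the original source): a pooling window
// over height = 32 with pad = [1, 1], kernel = 3 and stride = 2 gives
// padded = 32 + 1 + 1 - 3 = 31 and outputShape[1] = 31 / 2 + 1 = 16.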
3244 LogicalResult Conv2DOp::inferReturnTypeComponents(
3245 MLIRContext *context, ::std::optional<Location> location,
3246 Conv2DOp::Adaptor adaptor,
3250 int64_t inputWidth = ShapedType::kDynamic;
3251 int64_t inputHeight = ShapedType::kDynamic;
3252 int64_t weightWidth = ShapedType::kDynamic;
3253 int64_t weightHeight = ShapedType::kDynamic;
3258 if (inputShape.hasRank()) {
3259 outputShape[0] = inputShape.getDimSize(0);
3260 inputHeight = inputShape.getDimSize(1);
3261 inputWidth = inputShape.getDimSize(2);
3265 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3266 if (weightShape.hasRank()) {
3267 outputShape[3] = weightShape.getDimSize(0);
3268 weightHeight = weightShape.getDimSize(1);
3269 weightWidth = weightShape.getDimSize(2);
3274 if (biasShape.hasRank()) {
3275 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3276 ? biasShape.getDimSize(0)
3284 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3285 int64_t inputSize = inputHeight + padding[0] + padding[1];
3286 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3287 int64_t unstridedResult = inputSize - filterSize + 1;
3288 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3291 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3292 int64_t inputSize = inputWidth + padding[2] + padding[3];
3293 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3294 int64_t unstridedResult = inputSize - filterSize + 1;
3295 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
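// Worked example (illustrative, not in the original source): inputHeight = 32,
// padding = [1, 1], weightHeight = 3, dilation = 1 and stride = 1 give
// inputSize = 34, filterSize = 3, unstridedResult = 32 and
// outputShape[1] = (32 - 1) / 1 + 1 = 32, i.e. a "same" spatial size.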
3309 LogicalResult Conv3DOp::inferReturnTypeComponents(
3310 MLIRContext *context, ::std::optional<Location> location,
3311 Conv3DOp::Adaptor adaptor,
3315 int64_t inputWidth = ShapedType::kDynamic;
3316 int64_t inputHeight = ShapedType::kDynamic;
3317 int64_t inputDepth = ShapedType::kDynamic;
3319 int64_t weightWidth = ShapedType::kDynamic;
3320 int64_t weightHeight = ShapedType::kDynamic;
3321 int64_t weightDepth = ShapedType::kDynamic;
3325 if (inputShape.hasRank()) {
3326 outputShape[0] = inputShape.getDimSize(0);
3327 inputDepth = inputShape.getDimSize(1);
3328 inputHeight = inputShape.getDimSize(2);
3329 inputWidth = inputShape.getDimSize(3);
3333 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3334 if (weightShape.hasRank()) {
3335 outputShape[4] = weightShape.getDimSize(0);
3336 weightDepth = weightShape.getDimSize(1);
3337 weightHeight = weightShape.getDimSize(2);
3338 weightWidth = weightShape.getDimSize(3);
3343 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3344 outputShape[4] = biasShape.getDimSize(0);
3351 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3352 int32_t inputSize = inputDepth + pad[0] + pad[1];
3353 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3354 int32_t unstridedResult = inputSize - filterSize + 1;
3355 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3358 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3359 int32_t inputSize = inputHeight + pad[2] + pad[3];
3360 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3361 int32_t unstridedResult = inputSize - filterSize + 1;
3362 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3365 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3366 int32_t inputSize = inputWidth + pad[4] + pad[5];
3367 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3368 int32_t unstridedResult = inputSize - filterSize + 1;
3369 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3383 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3384 MLIRContext *context, ::std::optional<Location> location,
3385 AvgPool2dOp::Adaptor adaptor,
3388 const Properties &prop = adaptor.getProperties();
3390 inferredReturnShapes);
3393 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3394 MLIRContext *context, ::std::optional<Location> location,
3395 MaxPool2dOp::Adaptor adaptor,
3398 const Properties &prop = adaptor.getProperties();
3400 inferredReturnShapes);
3414 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3415 MLIRContext *context, ::std::optional<Location> location,
3416 DepthwiseConv2DOp::Adaptor adaptor,
3420 int64_t inputWidth = ShapedType::kDynamic;
3421 int64_t inputHeight = ShapedType::kDynamic;
3422 int64_t inputChannels = ShapedType::kDynamic;
3424 int64_t weightWidth = ShapedType::kDynamic;
3425 int64_t weightHeight = ShapedType::kDynamic;
3426 int64_t depthChannels = ShapedType::kDynamic;
3430 if (inputShape.hasRank()) {
3431 outputShape[0] = inputShape.getDimSize(0);
3432 inputHeight = inputShape.getDimSize(1);
3433 inputWidth = inputShape.getDimSize(2);
3434 inputChannels = inputShape.getDimSize(3);
3438 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3439 if (weightShape.hasRank()) {
3440 weightHeight = weightShape.getDimSize(0);
3441 weightWidth = weightShape.getDimSize(1);
3442 inputChannels = ShapedType::isDynamic(inputChannels)
3443 ? weightShape.getDimSize(2)
3445 depthChannels = weightShape.getDimSize(3);
3450 if (ShapedType::isStatic(inputChannels) &&
3451 ShapedType::isStatic(depthChannels)) {
3452 outputShape[3] = inputChannels * depthChannels;
3457 if (biasShape.hasRank()) {
3458 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3459 ? biasShape.getDimSize(0)
3467 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3468 int64_t inputSize = inputHeight + padding[0] + padding[1];
3469 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3470 int64_t unstridedResult = inputSize - filterSize + 1;
3471 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3474 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3475 int64_t inputSize = inputWidth + padding[2] + padding[3];
3476 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3477 int64_t unstridedResult = inputSize - filterSize + 1;
3478 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
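// Illustrative note (not part of the original source): for depthwise
// convolution the weight is laid out as [KH, KW, C, M], so with C = 8 input
// channels and a channel multiplier M = 2 the code above infers
// outputShape[3] = 8 * 2 = 16 output channels.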
3492 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3493 MLIRContext *context, ::std::optional<Location> location,
3494 TransposeConv2DOp::Adaptor adaptor,
3498 int64_t inputWidth = ShapedType::kDynamic;
3499 int64_t inputHeight = ShapedType::kDynamic;
3500 int64_t weightWidth = ShapedType::kDynamic;
3501 int64_t weightHeight = ShapedType::kDynamic;
3505 if (inputShape.hasRank()) {
3506 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3507 ? inputShape.getDimSize(0)
3509 inputHeight = inputShape.getDimSize(1);
3510 inputWidth = inputShape.getDimSize(2);
3514 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3515 if (weightShape.hasRank()) {
3516 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3517 ? weightShape.getDimSize(0)
3519 weightHeight = weightShape.getDimSize(1);
3520 weightWidth = weightShape.getDimSize(2);
3525 if (biasShape.hasRank()) {
3526 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3527 ? biasShape.getDimSize(0)
3534 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3535 int64_t calculateSize =
3536 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3538 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3541 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3542 int64_t calculateSize =
3543 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3545 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
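// Worked example (illustrative, not in the original source): for
// transpose_conv2d with inputHeight = 16, stride = 2, out_pad = [0, 0] and
// weightHeight = 3, the formula above gives (16 - 1) * 2 + 0 + 0 + 3 = 33 for
// the output height.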
3557 const int64_t strideY = strides[0];
3558 const int64_t strideX = strides[1];
3560 if (strideY < 1 || strideX < 1)
3561 return emitOpError("expect all stride values to be >= 1, got [")
3564 const auto checkPadAgainstKernelDim =
3565 [this](int64_t pad_value, int64_t kernel_dim_size,
3566 llvm::StringRef pad_name,
3567 llvm::StringRef kernel_dim_name) -> LogicalResult {
3568 if (pad_value <= -kernel_dim_size)
3569 return emitOpError("expected ")
3570 << pad_name << " > -" << kernel_dim_name
3571 << ", but got: " << pad_name << "=" << pad_value << " and "
3572 << kernel_dim_name << "=" << kernel_dim_size;
3577 const int64_t outPadTop = padding[0];
3578 const int64_t outPadBottom = padding[1];
3579 const int64_t outPadLeft = padding[2];
3580 const int64_t outPadRight = padding[3];
3582 const auto weightType =
3583 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3586 const int64_t kernelHeight = weightType.getDimSize(1);
3587 if (ShapedType::isStatic(kernelHeight)) {
3588 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3589 "out_pad_top", "KH")))
3592 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3593 "out_pad_bottom", "KH")))
3597 const int64_t kernelWidth = weightType.getDimSize(2);
3598 if (ShapedType::isStatic(kernelWidth)) {
3599 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3600 "out_pad_left", "KW")))
3603 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3604 "out_pad_right", "KW")))
3610 const auto outputType =
3611 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3615 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3616 if (inputType && weightType) {
3617 const int64_t inputHeight = inputType.getDimSize(1);
3618 const int64_t kernelHeight = weightType.getDimSize(1);
3619 const int64_t outputHeight = outputType.getDimSize(1);
3621 if (ShapedType::isStatic(inputHeight) &&
3622 ShapedType::isStatic(outputHeight)) {
3623 if (outputHeight !=
3624 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3625 return emitOpError(
3626 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3627 "+ out_pad_top + out_pad_bottom + KH, but got ")
3628 << outputHeight << " != (" << inputHeight << " - 1) * "
3629 << strideY << " + " << outPadTop << " + " << outPadBottom
3630 << " + " << kernelHeight;
3633 const int64_t inputWidth = inputType.getDimSize(2);
3634 const int64_t kernelWidth = weightType.getDimSize(2);
3635 const int64_t outputWidth = outputType.getDimSize(2);
3637 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3638 if (outputWidth !=
3639 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3640 return emitOpError(
3641 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3642 "+ out_pad_left + out_pad_right + KW, but got ")
3643 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3644 << " + " << outPadLeft << " + " << outPadRight << " + "
3645 << kernelWidth;
3649 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3654 const int64_t biasChannels = biasType.getDimSize(0);
3657 if (biasChannels == ShapedType::kDynamic)
3660 const int64_t outputChannels = outputType.getDimSize(3);
3661 if (!ShapedType::isDynamic(outputChannels) &&
3662 biasChannels != outputChannels && biasChannels != 1)
3664 "bias channels expected to be equal to output channels (")
3665 << outputChannels <<
") or 1, got " << biasChannels;
3671 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3672 if (!inputType) {
3673 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3677 auto inputElementType =
3679 if (!mlir::isa<IntegerType>(inputElementType)) {
3680 emitOpError("expect input to have integer element type, got ")
3681 << inputElementType;
3685 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3686 if (!outputType) {
3687 emitOpError("expect shaped tensor for output, got ")
3688 << getOutput().getType();
3692 auto outputElementType =
3694 if (!mlir::isa<IntegerType>(outputElementType)) {
3695 emitOpError("expect output to have integer element type, got ")
3696 << outputElementType;
3708 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3709 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3712 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3713 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3716 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3717 if (!multiplierType) {
3718 emitOpError("expect shaped tensor for multiplier, got ")
3719 << getMultiplier().getType();
3723 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3724 if (!shiftType) {
3725 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3730 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3731 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3732 << multiplierType.getElementType();
3737 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3739 "expect i16 element type for multiplier for scale32=false, got ")
3740 << multiplierType.getElementType();
3744 if (!inputType.hasRank())
3750 int64_t numChannels = 1;
3751 if (getPerChannel()) {
3752 if (inputType.getRank() < 1) {
3753 emitOpError("requires input to be at least rank 1 when per_channel is "
3754 "true, but got rank ")
3755 << inputType.getRank();
3758 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3761 if (!multiplierType.hasRank())
3766 if (multiplierShape[0] != ShapedType::kDynamic &&
3767 multiplierShape[0] != numChannels) {
3768 emitOpError("expect shape of { ")
3769 << numChannels << " } for multiplier input, got { "
3770 << multiplierShape[0] << " }";
3774 if (!shiftType.hasRank())
3779 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3780 emitOpError("expect shape of { ")
3781 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3788 LogicalResult RescaleOp::inferReturnTypeComponents(
3789 MLIRContext *context, ::std::optional<Location> location,
3790 RescaleOp::Adaptor adaptor,
3797 LogicalResult IfOp::inferReturnTypeComponents(
3798 MLIRContext *context, ::std::optional<Location> location,
3799 IfOp::Adaptor adaptor,
3802 for (Region *region : adaptor.getRegions()) {
3803 for (auto &block : *region)
3804 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3805 yieldOps.push_back(returnOp);
3808 if (yieldOps.empty())
3813 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3814 for (auto operand : yieldOps.front().getOperands()) {
3815 resultKnowledge.push_back(
3819 for (auto yieldOp : yieldOps) {
3820 if (resultKnowledge.size() != yieldOp.getNumOperands())
3824 int32_t index = it.index();
3826 resultKnowledge[index],
3830 resultKnowledge[index] = meet;
3835 inferredReturnShapes.push_back(result.getShapedTypeComponents());
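// Illustrative note (not part of the original source): each result's shape is
// the ValueKnowledge::meet of the corresponding operands of every yield, so a
// dimension is reported as static only when all yields agree on it; e.g.
// combining tensor<2x3xf32> with tensor<2x?xf32> keeps dimension 0 as 2 and
// leaves dimension 1 dynamic.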
3841 LogicalResult WhileOp::inferReturnTypeComponents(
3842 MLIRContext *context, ::std::optional<Location> location,
3843 WhileOp::Adaptor adaptor,
3846 for (auto &block : adaptor.getBodyGraph())
3847 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3848 yieldOps.push_back(returnOp);
3852 if (yieldOps.empty())
3857 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3858 for (auto operand : yieldOps.front().getOperands()) {
3859 resultKnowledge.push_back(
3863 for (auto yieldOp : yieldOps) {
3864 if (resultKnowledge.size() != yieldOp.getNumOperands())
3868 int32_t index = it.index();
3870 resultKnowledge[index],
3872 resultKnowledge[index] = meet;
3878 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3884 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3885 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3886 return llvm::to_vector<4>(vt.getShape());
3887 return std::nullopt;
3893 StringRef prefix = "") {
3894 assert(blocksArgs.size() == initializers.size() &&
3895 "expected same length of arguments and initializers");
3896 if (initializers.empty())
3899 parser << prefix << '(';
3900 llvm::interleaveComma(
3901 llvm::zip(blocksArgs, initializers), parser,
3902 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3930 "expected type for condition operand");
3936 "expected type for condition operand");
3944 FunctionType functionType;
3948 << "expected list of types for block arguments "
3949 << "followed by arrow type and list of return types";
3951 result.addTypes(functionType.getResults());
3953 if (functionType.getNumInputs() != operands.size()) {
3955 << "expected as many input types as operands "
3956 << "(expected " << operands.size() << " got "
3957 << functionType.getNumInputs() << ")";
3988 p << " " << getCondition();
3991 getInputList(), " ");
3993 p << getCondition().getType();
3995 if (!getInputList().empty()) {
3997 llvm::interleaveComma(getInputList().getTypes(), p);
4006 auto &elseRegion = getElseGraph();
4007 if (!elseRegion.empty()) {
4017 "'then_graph' arguments", getInputList(),
4023 "'else_graph' arguments", getInputList(),
4029 if (getThenGraph().front().mightHaveTerminator()) {
4031 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4033 *this, thenYield.getInputs(), "'then_graph' results",
4034 getOutputList(), "'output_list'")
4040 if (getElseGraph().front().mightHaveTerminator()) {
4042 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4044 *this, elseYield.getInputs(), "'else_graph' results",
4045 getOutputList(), "'output_list'")
4050 auto condType = getCondition().getType();
4052 return emitOpError() << "'condition' must be a size 1 tensor, got "
4060 getOutputList(), "'output_list'")
4065 "'cond_graph' arguments", getInputList(),
4071 "'body_graph' arguments", getInputList(),
4076 if (getBodyGraph().front().mightHaveTerminator()) {
4078 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4080 "'body_graph' results",
4081 getInputList(), "'input_list'")
4088 if (!getCondGraph().front().mightHaveTerminator())
4092 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4096 if (condYield.getInputs().size() != 1)
4097 return emitOpError() << "require 'cond_graph' only have one result";
4099 auto condOutType = condYield.getInputs()[0].getType();
4101 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4105 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4116 TensorType inputType = getInput1().getType();
4117 TensorType outputType = getOutput().getType();
4118 int32_t reverseAxis = getAxis();
4120 if (reverseAxis < 0)
4121 return emitOpError("expected non-negative reverse axis");
4123 int64_t inputRank = inputType.getRank();
4126 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
4127 return emitOpError("expect input tensor rank (")
4128 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4132 int64_t outputRank = outputType.getRank();
4133 if (inputType.hasRank() && outputRank != inputType.getRank())
4134 return emitOpError(
4135 "expect output tensor rank to be equal to input tensor rank");
4136 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
4137 return emitOpError("expect output tensor rank (")
4138 << outputRank << ") to be larger than reverse axis ("
4139 << reverseAxis << ")";
4155 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4156 if (!predicateType) {
4157 return emitOpError("expect shaped tensor for input1, got ")
4158 << getInput1().getType();
4160 auto predicateElementType = predicateType.getElementType();
4161 if (!predicateElementType.isInteger(1)) {
4162 return emitOpError("expect element type of bool for input1, got ")
4163 << predicateElementType;
4170 StringRef symName = getName();
4172 if (succeeded(varOp))
4173 return emitOpError("illegal to have multiple declaration of '")
4207 FunctionType functionType;
4212 result.addTypes(functionType.getResults());
4214 if (functionType.getNumInputs() != operands.size()) {
4216 << "expected as many input types as operands "
4217 << "(expected " << operands.size() << " got "
4218 << functionType.getNumInputs() << ")";
4228 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4229 regionArgs[i].type = functionType.getInput(i);
4231 return failure(parser.parseRegion(*cond, regionArgs) ||
4238 getInputList(), " ");
4241 getResults().getTypes());
4256 if (llvm::isa<FloatType>(srcElemType)) {
4258 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4259 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4261 if (llvm::isa<IntegerType>(srcElemType)) {
4264 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4266 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4267 return std::nullopt;
4275 return mlir::isa<tosa::shapeType>(t);
4282 return emitError() << "invalid rank (must be >= 0): " << rank;
4288 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4289 Operation *definingOp = v.getDefiningOp();
4291 return op->emitOpError("shape operand is not compile time resolvable");
4300 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4301 return op->emitOpError("must have operands with tosa shape type");
4305 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4306 return op->emitOpError("must have result with tosa shape type");
4319 auto getRank = [](const Type type) {
4320 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4326 for (auto type : operandTypes) {
4327 if (getRank(type) != rank) {
4328 return op->emitOpError("operands don't have matching ranks");
4331 for (auto type : resultTypes) {
4332 if (getRank(type) != rank) {
4333 return op->emitOpError("result shape has different rank than operands");
4345 auto valuesRank = getValues().getType().getRank();
4346 if (valuesRank != 1)
4347 return emitOpError("expect elements in attribute values with rank 1");
4349 auto count = getValues().getNumElements();
4350 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4351 if (count != rank && (count != 1 || rank != 0)) {
4352 return emitOpError("expect number of elements in attribute values (")
4353 << count << ") to be equal to the rank (" << rank
4354 << ") for the result shape type";
4363 #define GET_ATTRDEF_CLASSES
4364 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4369 #define GET_TYPEDEF_CLASSES
4370 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4376 #define GET_OP_CLASSES
4377 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"