#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
    return ::writeAttribute(attr, writer);

  LogicalResult writeType(Type type,
    return ::writeType(type, writer);

  std::unique_ptr<DialectVersion>
    reader.emitError("Dialect does not support versioning");

  LogicalResult upgradeFromVersion(Operation *topLevelOp,

  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;

  Type elementType = variableOp.getType();
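// Dialect initialization: registers the generated types, ops, and attributes,
// the inliner and bytecode interfaces, and the promised sharding interfaces.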
void TosaDialect::initialize() {
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();

  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
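// Custom parse/print helpers for tosa.variable: the declared type is split
// into a shape attribute and an element type attribute, and a typed initial
// value, when present, supplies that type.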
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
             << "expected ranked type";
    auto elementType = shapedType.getElementType();

         << "expected shaped type";

           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
           << "expected Typed attr";

  initialValueAttr = nullptr;
           << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);

                                       TypeAttr typeAttr,
                                       Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
  if (initialValueAttr) {
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
  llvm::StringRef name;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
      auto sym = symbolizeRoundingMode(kw);
               << "invalid rounding_mode value: " << kw;

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
               << "invalid resize mode value: " << kw;

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
               << "invalid nan_mode value: " << kw;

template <typename EnumType>
          [&]() { return parser.parseOperand(operands.emplace_back()); }))
    if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))

  parser << namedAttr.getName().strref() << " = ";
  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();

  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);

  llvm::interleaveComma(op->getAttrs(), parser,
                          printNamedAttr(parser, namedAttr);
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);
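// idivCheck: returns lhs / rhs only when rhs evenly divides lhs, and
// std::nullopt otherwise. Used by the conv, pooling, and resize shape checks
// below.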
std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();

                                   Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                                        Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
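// Shared verifier for convolution-style ops: input, weight, bias, result, and
// zero-point element types must be consistent, and float8 inputs require
// matching input/weight element types.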
template <typename T>
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
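// Constant op verification: the values attribute and the result must both be
// tensors with matching element types; a quantized result matches when its
// storage type equals the attribute element type.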
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
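// ERROR_IF checks for convolution output sizes: for each statically known
// spatial dimension the verifier requires
//   out = (in - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1
// with the division exact (see idivCheck).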
template <typename T>
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,

      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                StringRef name1, Type type2,
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
  tosa::VariableOp varOp = nullptr;

    if (auto tosaOp = dyn_cast<tosa::VariableOp>(tempOp)) {
      if (symName == tosaOp.getName()) {

template <typename T>
  StringRef symName = op.getName();
    return op->emitOpError("'")
           << symName << "' has not been declared by 'tosa.variable'";

template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
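// CLAMP verification: input and output element types must match, and the
// min_val/max_val attributes must use the matching integer or float attribute
// type, satisfy min_val <= max_val, and contain no NaNs.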
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};
  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)

  finalOutputType = outputShapedType.clone(accElementType);
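// The AVG_POOL2D and NEGATE builders below materialize the input/output zero
// points from the quantization attribute as constant tensors and pass them as
// extra operands.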
                               DenseArrayAttr kernel, DenseArrayAttr stride,
                               DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

  int64_t input1Zp{0};
  int64_t outputZp{0};
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                           StringRef name, Type variableType,
  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
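// Resolves the broadcast shape of the operand shapes: ranks are aligned from
// the right and dimensions of size 1 broadcast against the other operand.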
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    outShape.push_back(inputShape.getDimSize(i));
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

                                     const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
         << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                        outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
                                   "Cannot concat tensors with different sizes"
                                   " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;

  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] :
         llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  int64_t axisSum = 0;
  for (const auto &input : inputList) {
    if (inputShape.isDynamicDim(axis)) {

    axisSum += inputShape.getDimSize(axis);

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);

  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aElementType != bElementType) {
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);

  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
            "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";

  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();

      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;

  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);

  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());
  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum = std::accumulate(
        shapeValues.begin(), shapeValues.end(), 1LL,
        [](int64_t acc, int64_t dim) { return (dim > 0) ? acc * dim : acc; });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();
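// Zero-point helpers: getZeroPoint extracts the constant zero-point value from
// a zp operand (a float zero point folds to 0), and verifyZeroPoint enforces
// the per-type restrictions, e.g. non-int8 integer types require zp == 0 and
// unsigned int16 RESCALE operands allow 0 or 32768.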
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();

template <typename T>
                                 const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                 const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
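// TRANSPOSE shape inference and verification: perms must be a valid
// permutation of the input rank, and each output dimension must equal the
// corresponding permuted input dimension.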
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)

          builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
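// GATHER shape inference: the result shape is [N, W, C], with N and C taken
// from values and W from indices; the verifier cross-checks these sizes
// against every operand and the result.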
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;

    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
2742 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
2743 MLIRContext *context, ::std::optional<Location> location,
2744 ResizeOp::Adaptor adaptor,
2747 outputShape.resize(4, ShapedType::kDynamic);
2750 if (!inputShape.hasRank())
2753 outputShape[0] = inputShape.getDimSize(0);
2754 outputShape[3] = inputShape.getDimSize(3);
2755 int64_t inputHeight = inputShape.getDimSize(1);
2756 int64_t inputWidth = inputShape.getDimSize(2);
2758 if ((inputHeight == ShapedType::kDynamic) ||
2759 (inputWidth == ShapedType::kDynamic))
2773 const int64_t outputHeight =
2774 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
2778 const int64_t outputWidth =
2779 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
2783 if (outputHeight < 0 || outputWidth < 0) {
2786 "calculated output height and width must be non-negative, "
2788 outputHeight,
", width = ", outputWidth);
2791 outputShape[1] = outputHeight;
2792 outputShape[2] = outputWidth;
2798 const Value input = getInput();
2799 const Value output = getOutput();
2800 const RankedTensorType inputType =
2801 llvm::dyn_cast<RankedTensorType>(input.
getType());
2802 const RankedTensorType outputType =
2803 llvm::dyn_cast<RankedTensorType>(output.
getType());
2815 if (llvm::any_of(scaleValues, [](int64_t s) {
return s <= 0; }))
2816 return emitOpError(
"expect all scale values to be > 0, got ")
2819 const int64_t scaleYN = scaleValues[0];
2820 const int64_t scaleYD = scaleValues[1];
2821 const int64_t scaleXN = scaleValues[2];
2822 const int64_t scaleXD = scaleValues[3];
2824 const int64_t offsetY = offsetValues[0];
2825 const int64_t offsetX = offsetValues[1];
2827 const int64_t borderY = borderValues[0];
2828 const int64_t borderX = borderValues[1];
2835 const int64_t oh = outputType.getDimSize(1);
2836 const int64_t ow = outputType.getDimSize(2);
2837 const int64_t ih = inputType.getDimSize(1);
2838 const int64_t iw = inputType.getDimSize(2);
2844 if (ih != ShapedType::kDynamic && ih != 1) {
2845 const std::optional<int64_t> calculatedOutHeightMinusOne =
2846 idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
2847 if (!calculatedOutHeightMinusOne.has_value())
2848 return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
2850 << "to be wholly divisible by scale_y_d, got ((" << ih
2851 << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
2852 << ") / " << scaleYD;
2853 const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
2854 if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
2855 return emitOpError("calculated output height did not match expected: ")
2856 << "calculated=" << calculatedOutHeight << ", expected=" << oh;
2863 if (iw != ShapedType::kDynamic && iw != 1) {
2864 const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
2865 const std::optional<int64_t> calculatedOutWidthMinusOne =
2867 if (!calculatedOutWidthMinusOne.has_value())
2868 return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
2870 << "to be wholly divisible by scale_x_d, got ((" << iw
2871 << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
2872 << ") / " << scaleXD;
2873 const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
2874 if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
2875 return emitOpError("calculated output width did not match expected: ")
2876 << "calculated=" << calculatedOutWidth << ", expected=" << ow;
2882 LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
2883 MLIRContext *context, ::std::optional<Location> location,
2884 ScatterOp::Adaptor adaptor,
2887 outputShape.resize(3, ShapedType::kDynamic);
2889 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
2890 if (valuesInShape.hasRank()) {
2891 outputShape[0] = valuesInShape.getDimSize(0);
2892 outputShape[1] = valuesInShape.getDimSize(1);
2893 outputShape[2] = valuesInShape.getDimSize(2);
2896 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
2897 if (indicesShape.hasRank()) {
2898 if (outputShape[0] == ShapedType::kDynamic)
2899 outputShape[0] = indicesShape.getDimSize(0);
2903 if (inputShape.hasRank()) {
2904 if (outputShape[0] == ShapedType::kDynamic)
2905 outputShape[0] = inputShape.getDimSize(0);
2906 if (outputShape[2] == ShapedType::kDynamic)
2907 outputShape[2] = inputShape.getDimSize(2);
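// Illustrative example (hypothetical shapes, not taken from this file): the
// scatter inference and verifier below relate values_in:[N, K, C],
// indices:[N, W], input:[N, W, C] and values_out:[N, K, C], with K >= W, e.g.
//   values_in : tensor<2x16x8xf32>  (N=2, K=16, C=8)
//   indices   : tensor<2x4xi32>     (N=2, W=4)
//   input     : tensor<2x4x8xf32>   (N=2, W=4, C=8)
//   values_out: tensor<2x16x8xf32>  (K=16 >= W=4)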
2929 int64_t N = ShapedType::kDynamic;
2930 int64_t K = ShapedType::kDynamic;
2931 int64_t W = ShapedType::kDynamic;
2932 int64_t C = ShapedType::kDynamic;
2933 if (valuesInShape.hasRank()) {
2934 N = valuesInShape.getDimSize(0);
2935 K = valuesInShape.getDimSize(1);
2936 C = valuesInShape.getDimSize(2);
2938 if (indicesShape.hasRank()) {
2939 const int64_t indicesN = indicesShape.getDimSize(0);
2940 W = indicesShape.getDimSize(1);
2941 if (N == ShapedType::kDynamic)
2943 else if (indicesN != ShapedType::kDynamic && N != indicesN)
2944 return emitOpError() << "requires indices dimension 0 to have size " << N
2945 << ", got " << indicesN;
2947 if (inputShape.hasRank()) {
2948 const int64_t inputN = inputShape.getDimSize(0);
2949 const int64_t inputW = inputShape.getDimSize(1);
2950 const int64_t inputC = inputShape.getDimSize(2);
2951 if (N == ShapedType::kDynamic)
2953 else if (inputN != ShapedType::kDynamic && N != inputN)
2954 return emitOpError() << "requires input dimension 0 to have size " << N
2955 << ", got " << inputN;
2956 if (W == ShapedType::kDynamic)
2958 else if (inputW != ShapedType::kDynamic && W != inputW)
2959 return emitOpError() << "requires input dimension 1 to have size " << W
2960 << ", got " << inputW;
2962 if (C == ShapedType::kDynamic)
2964 else if (inputC != ShapedType::kDynamic && C != inputC)
2965 return emitOpError() << "requires input dimension 2 to have size " << C
2966 << ", got " << inputC;
2968 if (outputShape.hasRank()) {
2969 const int64_t outputN = outputShape.getDimSize(0);
2970 const int64_t outputK = outputShape.getDimSize(1);
2971 const int64_t outputC = outputShape.getDimSize(2);
2972 if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2974 return emitOpError() << "requires values_out dimension 0 to have size "
2975 << N << ", got " << outputN;
2976 if (K == ShapedType::kDynamic)
2978 else if (outputK != ShapedType::kDynamic && K != outputK)
2979 return emitOpError() << "requires values_out dimension 1 to have size "
2980 << K << ", got " << outputK;
2981 if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2983 return emitOpError() << "requires values_out dimension 2 to have size "
2984 << C << ", got " << outputC;
2986 if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2987 return emitOpError() << "requires dimensions K >= W, got K=" << K
2996 int64_t axisVal = axis.getValue().getSExtValue();
2997 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3003 operandShape.getDims(outputShape);
3004 outputShape[axisVal] = 1;
3009 #define COMPATIBLE_RETURN_TYPES(OP) \
3010 bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
3011 if (l.size() != r.size() || l.size() != 1) \
3013 if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
3015 return succeeded(verifyCompatibleShape(l[0], r[0])); \
3018 #define REDUCE_SHAPE_INFER(OP) \
3019 LogicalResult OP::inferReturnTypeComponents( \
3020 MLIRContext *context, ::std::optional<Location> location, \
3021 OP::Adaptor adaptor, \
3022 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3024 llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
3025 ShapeAdaptor inputShape(adaptor.getInput().getType()); \
3026 const Properties &prop = adaptor.getProperties(); \
3027 return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
3028 inferredReturnShapes); \
3030 COMPATIBLE_RETURN_TYPES(OP)
3038 #undef REDUCE_SHAPE_INFER
3040 #undef COMPATIBLE_RETURN_TYPES
3042 template <typename T>
3045 TensorType inputType = op.getInput().getType();
3046 TensorType outputType = op.getOutput().getType();
3047 int32_t reduceAxis = op.getAxis();
3049 if (reduceAxis < 0) {
3050 op.emitOpError("reduce axis must not be negative");
3054 int64_t inputRank = inputType.getRank();
3057 if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
3058 op.emitOpError("expect input tensor rank (")
3059 << inputRank << ") to be larger than reduce axis (" << reduceAxis
3065 int64_t outputRank = outputType.getRank();
3066 if (inputType.hasRank() && outputRank != inputType.getRank()) {
3068 "expect output tensor rank to be equal to input tensor rank");
3071 if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
3072 op.emitOpError("expect output tensor rank (")
3073 << outputRank << ") to be larger than reduce axis (" << reduceAxis
3079 if (outputRank != 0) {
3080 auto outputShape = outputType.getShape();
3081 if (!outputType.isDynamicDim(reduceAxis) &&
3082 outputShape[reduceAxis] != 1) {
3083 op.emitOpError("expect reduced dimension size to be 1, got ")
3084 << outputShape[reduceAxis];
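// Illustrative example (hypothetical shapes): the reduce verifier above
// requires axis < rank and the reduced output dimension to be 1, e.g.
// reducing axis 1 of tensor<2x3x4xf32> yields tensor<2x1x4xf32>.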
3111 #define NARY_SHAPE_INFER(OP) \
3112 LogicalResult OP::inferReturnTypeComponents( \
3113 MLIRContext *context, ::std::optional<Location> location, \
3114 ValueShapeRange operands, DictionaryAttr attributes, \
3115 OpaqueProperties properties, RegionRange regions, \
3116 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3117 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3157 #undef PRED_SHAPE_INFER
3159 LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3160 MLIRContext *context, ::std::optional<Location> location,
3161 NegateOp::Adaptor adaptor,
3163 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3170 const Type input1Type = getInput1().getType();
3171 const Type outputType = getOutput().getType();
3178 return emitOpError() << "requires the same shape for input1 and output";
3181 const Type input1ZpEType =
3183 if (input1EType != input1ZpEType) {
3184 return emitOpError("expect both input1 and its zero point are the same "
3185 "element type, got ")
3186 << input1EType << " and " << input1ZpEType;
3189 const Type outputZpEType =
3191 if (outputEType != outputZpEType) {
3192 return emitOpError("expect both output and its zero point are the same "
3193 "element type, got ")
3194 << outputEType << " and " << outputZpEType;
3197 FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
3198 if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
3201 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3202 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3213 outputShape.resize(4, ShapedType::kDynamic);
3228 if (ShapedType::isStatic(height)) {
3229 int64_t padded = height + pad[0] + pad[1] - kernel[0];
3230 outputShape[1] = padded / stride[0] + 1;
3233 if (ShapedType::isStatic(width)) {
3234 int64_t padded = width + pad[2] + pad[3] - kernel[1];
3235 outputShape[2] = padded / stride[1] + 1;
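// Worked example of the pooling output-size arithmetic above (hypothetical
// values): out = (in + pad_before + pad_after - kernel) / stride + 1, e.g.
// height 32 with pad {1, 1}, kernel 3, stride 2 gives
// (32 + 1 + 1 - 3) / 2 + 1 == 16.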
3242 LogicalResult Conv2DOp::inferReturnTypeComponents(
3243 MLIRContext *context, ::std::optional<Location> location,
3244 Conv2DOp::Adaptor adaptor,
3248 int64_t inputWidth = ShapedType::kDynamic;
3249 int64_t inputHeight = ShapedType::kDynamic;
3250 int64_t weightWidth = ShapedType::kDynamic;
3251 int64_t weightHeight = ShapedType::kDynamic;
3256 if (inputShape.hasRank()) {
3257 outputShape[0] = inputShape.getDimSize(0);
3258 inputHeight = inputShape.getDimSize(1);
3259 inputWidth = inputShape.getDimSize(2);
3263 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3264 if (weightShape.hasRank()) {
3265 outputShape[3] = weightShape.getDimSize(0);
3266 weightHeight = weightShape.getDimSize(1);
3267 weightWidth = weightShape.getDimSize(2);
3272 if (biasShape.hasRank()) {
3273 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3274 ? biasShape.getDimSize(0)
3282 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3283 int64_t inputSize = inputHeight + padding[0] + padding[1];
3284 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3285 int64_t unstridedResult = inputSize - filterSize + 1;
3286 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3289 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3290 int64_t inputSize = inputWidth + padding[2] + padding[3];
3291 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3292 int64_t unstridedResult = inputSize - filterSize + 1;
3293 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
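// A minimal standalone sketch (not part of TosaOps.cpp) of the convolution
// output-size arithmetic above, with hypothetical values:
//   filter    = (kernel - 1) * dilation + 1
//   unstrided = in + pad_before + pad_after - filter + 1
//   out       = (unstrided - 1) / stride + 1
#include <cstdint>

constexpr int64_t convOutputDim(int64_t in, int64_t padBefore, int64_t padAfter,
                                int64_t kernel, int64_t dilation,
                                int64_t stride) {
  const int64_t filter = (kernel - 1) * dilation + 1;
  const int64_t unstrided = in + padBefore + padAfter - filter + 1;
  return (unstrided - 1) / stride + 1;
}

// e.g. a height-32 input, 3x3 kernel, dilation 1, pad {1, 1}, stride 2:
// ((32 + 1 + 1 - 3 + 1) - 1) / 2 + 1 == 16.
static_assert(convOutputDim(32, 1, 1, 3, 1, 2) == 16, "conv size example");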
3307 LogicalResult Conv3DOp::inferReturnTypeComponents(
3308 MLIRContext *context, ::std::optional<Location> location,
3309 Conv3DOp::Adaptor adaptor,
3313 int64_t inputWidth = ShapedType::kDynamic;
3314 int64_t inputHeight = ShapedType::kDynamic;
3315 int64_t inputDepth = ShapedType::kDynamic;
3317 int64_t weightWidth = ShapedType::kDynamic;
3318 int64_t weightHeight = ShapedType::kDynamic;
3319 int64_t weightDepth = ShapedType::kDynamic;
3323 if (inputShape.hasRank()) {
3324 outputShape[0] = inputShape.getDimSize(0);
3325 inputDepth = inputShape.getDimSize(1);
3326 inputHeight = inputShape.getDimSize(2);
3327 inputWidth = inputShape.getDimSize(3);
3331 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3332 if (weightShape.hasRank()) {
3333 outputShape[4] = weightShape.getDimSize(0);
3334 weightDepth = weightShape.getDimSize(1);
3335 weightHeight = weightShape.getDimSize(2);
3336 weightWidth = weightShape.getDimSize(3);
3341 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
3342 outputShape[4] = biasShape.getDimSize(0);
3349 if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
3350 int32_t inputSize = inputDepth + pad[0] + pad[1];
3351 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
3352 int32_t unstridedResult = inputSize - filterSize + 1;
3353 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3356 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3357 int32_t inputSize = inputHeight + pad[2] + pad[3];
3358 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
3359 int32_t unstridedResult = inputSize - filterSize + 1;
3360 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
3363 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3364 int32_t inputSize = inputWidth + pad[4] + pad[5];
3365 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
3366 int32_t unstridedResult = inputSize - filterSize + 1;
3367 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
3381 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
3382 MLIRContext *context, ::std::optional<Location> location,
3383 AvgPool2dOp::Adaptor adaptor,
3386 const Properties &prop = adaptor.getProperties();
3388 inferredReturnShapes);
3391 LogicalResult MaxPool2dOp::inferReturnTypeComponents(
3392 MLIRContext *context, ::std::optional<Location> location,
3393 MaxPool2dOp::Adaptor adaptor,
3396 const Properties &prop = adaptor.getProperties();
3398 inferredReturnShapes);
3412 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
3413 MLIRContext *context, ::std::optional<Location> location,
3414 DepthwiseConv2DOp::Adaptor adaptor,
3418 int64_t inputWidth = ShapedType::kDynamic;
3419 int64_t inputHeight = ShapedType::kDynamic;
3420 int64_t inputChannels = ShapedType::kDynamic;
3422 int64_t weightWidth = ShapedType::kDynamic;
3423 int64_t weightHeight = ShapedType::kDynamic;
3424 int64_t depthChannels = ShapedType::kDynamic;
3428 if (inputShape.hasRank()) {
3429 outputShape[0] = inputShape.getDimSize(0);
3430 inputHeight = inputShape.getDimSize(1);
3431 inputWidth = inputShape.getDimSize(2);
3432 inputChannels = inputShape.getDimSize(3);
3436 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3437 if (weightShape.hasRank()) {
3438 weightHeight = weightShape.getDimSize(0);
3439 weightWidth = weightShape.getDimSize(1);
3440 inputChannels = ShapedType::isDynamic(inputChannels)
3441 ? weightShape.getDimSize(2)
3443 depthChannels = weightShape.getDimSize(3);
3448 if (ShapedType::isStatic(inputChannels) &&
3449 ShapedType::isStatic(depthChannels)) {
3450 outputShape[3] = inputChannels * depthChannels;
3455 if (biasShape.hasRank()) {
3456 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3457 ? biasShape.getDimSize(0)
3465 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3466 int64_t inputSize = inputHeight + padding[0] + padding[1];
3467 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
3468 int64_t unstridedResult = inputSize - filterSize + 1;
3469 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
3472 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3473 int64_t inputSize = inputWidth + padding[2] + padding[3];
3474 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
3475 int64_t unstridedResult = inputSize - filterSize + 1;
3476 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
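// Illustrative example (hypothetical shapes): for depthwise convolution the
// output channel count computed above is inputChannels * depthChannels, e.g.
// an input of tensor<1x32x32x8xf32> with a weight of tensor<3x3x8x2xf32>
// (multiplier M = 2) produces 8 * 2 = 16 output channels.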
3490 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
3491 MLIRContext *context, ::std::optional<Location> location,
3492 TransposeConv2DOp::Adaptor adaptor,
3496 int64_t inputWidth = ShapedType::kDynamic;
3497 int64_t inputHeight = ShapedType::kDynamic;
3498 int64_t weightWidth = ShapedType::kDynamic;
3499 int64_t weightHeight = ShapedType::kDynamic;
3503 if (inputShape.hasRank()) {
3504 outputShape[0] = ShapedType::isDynamic(outputShape[0])
3505 ? inputShape.getDimSize(0)
3507 inputHeight = inputShape.getDimSize(1);
3508 inputWidth = inputShape.getDimSize(2);
3512 ShapeAdaptor weightShape(adaptor.getWeight().getType());
3513 if (weightShape.hasRank()) {
3514 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3515 ? weightShape.getDimSize(0)
3517 weightHeight = weightShape.getDimSize(1);
3518 weightWidth = weightShape.getDimSize(2);
3523 if (biasShape.hasRank()) {
3524 outputShape[3] = ShapedType::isDynamic(outputShape[3])
3525 ? biasShape.getDimSize(0)
3532 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
3533 int64_t calculateSize =
3534 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
3536 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
3539 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
3540 int64_t calculateSize =
3541 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
3543 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
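// Worked example of the transpose-conv output-size arithmetic above
// (hypothetical values): out = (in - 1) * stride + out_pad_before +
// out_pad_after + kernel, e.g. a height-16 input with stride 2, zero output
// padding and KH = 3 gives (16 - 1) * 2 + 0 + 0 + 3 == 33.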
3555 const int64_t strideY = strides[0];
3556 const int64_t strideX = strides[1];
3558 if (strideY < 1 || strideX < 1)
3559 return emitOpError("expect all stride values to be >= 1, got [")
3562 const auto checkPadAgainstKernelDim =
3563 [this](int64_t pad_value, int64_t kernel_dim_size,
3564 llvm::StringRef pad_name,
3565 llvm::StringRef kernel_dim_name) -> LogicalResult {
3566 if (pad_value <= -kernel_dim_size)
3567 return emitOpError("expected ")
3568 << pad_name << " > -" << kernel_dim_name
3569 << ", but got: " << pad_name << "=" << pad_value << " and "
3570 << kernel_dim_name << "=" << kernel_dim_size;
3575 const int64_t outPadTop = padding[0];
3576 const int64_t outPadBottom = padding[1];
3577 const int64_t outPadLeft = padding[2];
3578 const int64_t outPadRight = padding[3];
3580 const auto weightType =
3581 llvm::dyn_cast<RankedTensorType>(getWeight().getType());
3584 const int64_t kernelHeight = weightType.getDimSize(1);
3585 if (ShapedType::isStatic(kernelHeight)) {
3586 if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
3587 "out_pad_top", "KH")))
3590 if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
3591 "out_pad_bottom", "KH")))
3595 const int64_t kernelWidth = weightType.getDimSize(2);
3596 if (ShapedType::isStatic(kernelWidth)) {
3597 if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
3598 "out_pad_left", "KW")))
3601 if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
3602 "out_pad_right", "KW")))
3608 const auto outputType =
3609 llvm::dyn_cast<RankedTensorType>(getOutput().getType());
3613 const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
3614 if (inputType && weightType) {
3615 const int64_t inputHeight = inputType.getDimSize(1);
3616 const int64_t kernelHeight = weightType.getDimSize(1);
3617 const int64_t outputHeight = outputType.getDimSize(1);
3619 if (ShapedType::isStatic(inputHeight) &&
3620 ShapedType::isStatic(outputHeight)) {
3622 (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
3624 "dimension mismatch: expected OH == (IH - 1) * stride_y "
3625 "+ out_pad_top + out_pad_bottom + KH, but got ")
3626 << outputHeight << " != (" << inputHeight << " - 1) * "
3627 << strideY << " + " << outPadTop << " + " << outPadBottom
3628 << " + " << kernelHeight;
3631 const int64_t inputWidth = inputType.getDimSize(2);
3632 const int64_t kernelWidth = weightType.getDimSize(2);
3633 const int64_t outputWidth = outputType.getDimSize(2);
3635 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
3637 (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
3639 "dimension mismatch: expected OW == (IW - 1) * stride_x "
3640 "+ out_pad_left + out_pad_right + KW, but got ")
3641 << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
3642 << " + " << outPadLeft << " + " << outPadRight << " + "
3647 const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
3652 const int64_t biasChannels = biasType.getDimSize(0);
3655 if (biasChannels == ShapedType::kDynamic)
3658 const int64_t outputChannels = outputType.getDimSize(3);
3659 if (!ShapedType::isDynamic(outputChannels) &&
3660 biasChannels != outputChannels && biasChannels != 1)
3662 "bias channels expected to be equal to output channels (")
3663 << outputChannels << ") or 1, got " << biasChannels;
3669 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
3671 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
3675 auto inputElementType =
3677 if (!mlir::isa<IntegerType>(inputElementType)) {
3678 emitOpError("expect input to have integer element type, got ")
3679 << inputElementType;
3683 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
3685 emitOpError("expect shaped tensor for output, got ")
3686 << getOutput().getType();
3690 auto outputElementType =
3692 if (!mlir::isa<IntegerType>(outputElementType)) {
3693 emitOpError("expect output to have integer element type, got ")
3694 << outputElementType;
3706 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
3707 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
3710 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
3711 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
3714 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
3715 if (!multiplierType) {
3716 emitOpError("expect shaped tensor for multiplier, got ")
3717 << getMultiplier().getType();
3721 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
3723 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
3728 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
3729 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
3730 << multiplierType.getElementType();
3735 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
3737 "expect i16 element type for multiplier for scale32=false, got ")
3738 << multiplierType.getElementType();
3742 if (!inputType.hasRank())
3748 int64_t numChannels = 1;
3749 if (getPerChannel()) {
3750 if (inputType.getRank() < 1) {
3751 emitOpError("requires input to be at least rank 1 when per_channel is "
3752 "true, but got rank ")
3753 << inputType.getRank();
3756 numChannels = inputType.getDimSize(inputType.getRank() - 1);
3759 if (!multiplierType.hasRank())
3764 if (multiplierShape[0] != ShapedType::kDynamic &&
3765 multiplierShape[0] != numChannels) {
3766 emitOpError("expect shape of { ")
3767 << numChannels << " } for multiplier input, got { "
3768 << multiplierShape[0] << " }";
3772 if (!shiftType.hasRank())
3777 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
3778 emitOpError("expect shape of { ")
3779 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
3786 LogicalResult RescaleOp::inferReturnTypeComponents(
3787 MLIRContext *context, ::std::optional<Location> location,
3788 RescaleOp::Adaptor adaptor,
3795 LogicalResult IfOp::inferReturnTypeComponents(
3796 MLIRContext *context, ::std::optional<Location> location,
3797 IfOp::Adaptor adaptor,
3800 for (Region *region : adaptor.getRegions()) {
3801 for (auto &block : *region)
3802 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3803 yieldOps.push_back(returnOp);
3806 if (yieldOps.empty())
3811 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3812 for (auto operand : yieldOps.front().getOperands()) {
3813 resultKnowledge.push_back(
3817 for (auto yieldOp : yieldOps) {
3818 if (resultKnowledge.size() != yieldOp.getNumOperands())
3822 int32_t index = it.index();
3824 resultKnowledge[index],
3828 resultKnowledge[index] = meet;
3833 inferredReturnShapes.push_back(result.getShapedTypeComponents());
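// Illustrative example (hypothetical shapes): the ValueKnowledge::meet calls
// above keep only the shape information on which all yields agree, so if one
// branch yields tensor<1x8xf32> and the other tensor<1x?xf32>, the inferred
// result component is tensor<1x?xf32>.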
3839 LogicalResult WhileOp::inferReturnTypeComponents(
3840 MLIRContext *context, ::std::optional<Location> location,
3841 WhileOp::Adaptor adaptor,
3844 for (auto &block : adaptor.getBodyGraph())
3845 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
3846 yieldOps.push_back(returnOp);
3850 if (yieldOps.empty())
3855 resultKnowledge.reserve(yieldOps.front().getNumOperands());
3856 for (auto operand : yieldOps.front().getOperands()) {
3857 resultKnowledge.push_back(
3861 for (auto yieldOp : yieldOps) {
3862 if (resultKnowledge.size() != yieldOp.getNumOperands())
3866 int32_t index = it.index();
3868 resultKnowledge[index],
3870 resultKnowledge[index] = meet;
3876 inferredReturnShapes.push_back(result.getShapedTypeComponents());
3882 std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
3883 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
3884 return llvm::to_vector<4>(vt.getShape());
3885 return std::nullopt;
3891 StringRef prefix = "") {
3892 assert(blocksArgs.size() == initializers.size() &&
3893 "expected same length of arguments and initializers");
3894 if (initializers.empty())
3897 parser << prefix << '(';
3898 llvm::interleaveComma(
3899 llvm::zip(blocksArgs, initializers), parser,
3900 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
3928 "expected type for condition operand");
3934 "expected type for condition operand");
3942 FunctionType functionType;
3946 << "expected list of types for block arguments "
3947 << "followed by arrow type and list of return types";
3949 result.addTypes(functionType.getResults());
3951 if (functionType.getNumInputs() != operands.size()) {
3953 << "expected as many input types as operands "
3954 << "(expected " << operands.size() << " got "
3955 << functionType.getNumInputs() << ")";
3986 p << " " << getCondition();
3989 getInputList(), " ");
3991 p << getCondition().getType();
3993 if (!getInputList().empty()) {
3995 llvm::interleaveComma(getInputList().getTypes(), p);
4004 auto &elseRegion = getElseGraph();
4005 if (!elseRegion.empty()) {
4015 "'then_graph' arguments", getInputList(),
4021 "'else_graph' arguments", getInputList(),
4026 auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4028 "'then_graph' results", getOutputList(),
4033 auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4035 "'else_graph' results", getOutputList(),
4040 auto condType = getCondition().getType();
4042 return emitOpError() << "'condition' must be a size 1 tensor, got "
4050 getOutputList(), "'output_list'")
4055 "'cond_graph' arguments", getInputList(),
4061 "'body_graph' arguments", getInputList(),
4066 auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
4068 "'body_graph' results", getInputList(),
4075 auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
4076 if (condYield.getInputs().size() != 1)
4077 return emitOpError() << "require 'cond_graph' only have one result";
4079 auto condOutType = condYield.getInputs()[0].getType();
4081 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
4085 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
4096 TensorType inputType = getInput1().getType();
4097 TensorType outputType = getOutput().getType();
4098 int32_t reverseAxis = getAxis();
4100 if (reverseAxis < 0)
4101 return emitOpError("expected non-negative reverse axis");
4103 int64_t inputRank = inputType.getRank();
4106 if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0))
4107 return emitOpError("expect input tensor rank (")
4108 << inputRank << ") to be larger than reverse axis (" << reverseAxis
4112 int64_t outputRank = outputType.getRank();
4113 if (inputType.hasRank() && outputRank != inputType.getRank())
4115 "expect output tensor rank to be equal to input tensor rank");
4116 if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0))
4117 return emitOpError("expect output tensor rank (")
4118 << outputRank << ") to be larger than reverse axis ("
4119 << reverseAxis << ")";
4135 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
4136 if (!predicateType) {
4137 return emitOpError("expect shaped tensor for input1, got ")
4138 << getInput1().getType();
4140 auto predicateElementType = predicateType.getElementType();
4141 if (!predicateElementType.isInteger(1)) {
4142 return emitOpError("expect element type of bool for input1, got ")
4143 << predicateElementType;
4150 StringRef symName = getName();
4152 if (succeeded(varOp))
4153 return emitOpError("illegal to have multiple declaration of '")
4187 FunctionType functionType;
4192 result.addTypes(functionType.getResults());
4194 if (functionType.getNumInputs() != operands.size()) {
4196 << "expected as many input types as operands "
4197 << "(expected " << operands.size() << " got "
4198 << functionType.getNumInputs() << ")";
4208 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
4209 regionArgs[i].type = functionType.getInput(i);
4211 return failure(parser.parseRegion(*cond, regionArgs) ||
4218 getInputList(), " ");
4221 getResults().getTypes());
4236 if (llvm::isa<FloatType>(srcElemType)) {
4238 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
4239 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4241 if (llvm::isa<IntegerType>(srcElemType)) {
4244 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
4246 llvm::errs() << "zero point is not allowed for unsupported data types\n";
4247 return std::nullopt;
4255 return mlir::isa<tosa::shapeType>(t);
4262 return emitError() << "invalid rank (must be >= 0): " << rank;
4268 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
4269 Operation *definingOp = v.getDefiningOp();
4271 return op->emitOpError("shape operand is not compile time resolvable");
4280 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4281 return op->emitOpError("must have operands with tosa shape type");
4285 if (!mlir::isa<mlir::tosa::shapeType>(type)) {
4286 return op->emitOpError("must have result with tosa shape type");
4299 auto getRank = [](const Type type) {
4300 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
4306 for (auto type : operandTypes) {
4307 if (getRank(type) != rank) {
4308 return op->emitOpError("operands don't have matching ranks");
4311 for (auto type : resultTypes) {
4312 if (getRank(type) != rank) {
4313 return op->emitOpError("result shape has different rank than operands");
4325 auto valuesRank = getValues().getType().getRank();
4326 if (valuesRank != 1)
4327 return emitOpError("expect elements in attribute values with rank 1");
4329 auto count = getValues().getNumElements();
4330 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
4331 if (!(count == rank || (count == 1 && rank == 0))) {
4332 return emitOpError("expect number of elements in attribute values (")
4333 << count << ") to be equal to the rank (" << rank
4334 << ") for the result shape type";
4343 #define GET_ATTRDEF_CLASSES
4344 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
4349 #define GET_TYPEDEF_CLASSES
4350 #include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
4356 #define GET_OP_CLASSES
4357 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"