#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
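// The .inc / .cpp.inc files included above are generated by TableGen from the
// TOSA dialect definitions; they supply the dialect declaration, enum and
// attribute helpers, op interfaces, availability queries, and bytecode
// support used throughout this file.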
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));

  TosaDialectBytecodeInterface(Dialect *dialect)

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    reader.emitError("Dialect does not support versioning");
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,

  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));

  Type elementType = variableOp.getType();
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
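// Constant materialization for the dialect: shape-typed values backed by a
// DenseIntElementsAttr become tosa.const_shape, while any other ElementsAttr
// becomes tosa.const.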
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();

         << "expected shaped type";

           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
           << "expected Typed attr";

  initialValueAttr = nullptr;
         << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
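// Custom assembly handling for tosa.variable: the parser accepts either an
// initial value ("= <typed attribute>") or a bare ": <type>", and in both
// cases derives the variable's shape and element type via
// getShapeAndElementType above.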
                                         TypeAttr typeAttr,
                                         Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =

  if (initialValueAttr) {
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
  llvm::StringRef name;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
      auto sym = symbolizeRoundingMode(kw);
          << "invalid rounding_mode value: " << kw;

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
          << "invalid resize mode value: " << kw;

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
          << "invalid nan_mode value: " << kw;

template <typename EnumType>
          [&]() { return parser.parseOperand(operands.emplace_back()); }))

    if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
  parser << namedAttr.getName().strref() << " = ";

  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();

  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);

  llvm::interleaveComma(op->getAttrs(), parser,
                          printNamedAttr(parser, namedAttr);
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
  printWithNanPropagationHandling(parser, *this);
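// The parse()/print() hooks above all route through the shared enum-aware
// helpers: enum-valued attributes (rounding_mode, resize mode, nan_mode) are
// printed as bare keywords, and printWithNanPropagationHandling additionally
// elides nan_mode when it equals the PROPAGATE default.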
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();
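// idivCheck returns lhs / rhs only when rhs divides lhs exactly, and
// std::nullopt otherwise; the shape verifiers below use it to enforce their
// "wholly divisible" constraints.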
                                     Value valZp, StringRef name) {
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;

                                Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
template <typename T>
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());
  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
template <typename T>
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
template <typename T>
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,

      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;
                                 StringRef name1, Type type2,
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
        << "require same element type for " << name1 << " (" << elemType1
        << ") and " << name2 << " (" << elemType2 << ")";

        << "require same shapes for " << name1 << " (" << type1 << ") and "
        << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
        << "require same number of values in " << name1 << " ("
        << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
template <typename T>
      op->template getParentWithTrait<OpTrait::SymbolTable>();

  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
           << op.getName() << "' has not been declared by 'tosa.variable'";
template <typename T>
  auto inputType = llvm::dyn_cast<TensorType>(inType);
  auto outputType = llvm::dyn_cast<TensorType>(outType);

    op.emitOpError("expect shaped tensor for input, got ") << inType;

    op.emitOpError("expect shaped tensor for output, got ") << outType;

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")

  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")

  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();

      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};
  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)
  finalOutputType = outputShapedType.clone(accElementType);

                                     DenseArrayAttr kernel, DenseArrayAttr stride,
                                     DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> inputZpOp =
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);
  int64_t input1Zp{0};
  int64_t outputZp{0};
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =
        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =
        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());
  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                                  StringRef name, Type variableType,
  auto shapedType = dyn_cast<ShapedType>(variableType);
    (void)emitError(loc, "variable type must be a shaped type");

  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");

  auto elementType = shapedType.getElementType();
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    if (!shape.hasRank()) {

    outRank = std::max<int64_t>(outRank, shape.getRank());

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      } else if (dim2 == 1) {
      } else if (dim1 != dim2) {

      outShape[i + rankDiff] = resolvedDim;
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {

  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    outShape.push_back(inputShape.getDimSize(i));

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())

  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;
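// RFFT2d keeps only the non-redundant half of the spectrum along the last
// axis, so the inferred output width is input_width / 2 + 1.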
                                          const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
           << dimName << " to be a power of two, got " << dimSize;

  const auto outputTypes = getResultTypes();
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);

                        outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
        "expected output width to be equal to input_width / 2 + 1, got ")
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
  inferredReturnShapes.push_back(
  inferredReturnShapes.push_back(

  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    if (!operandShape.hasRank())

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))

      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
            "Cannot concat tensors with different sizes"
            " on the non-axis dimension ",

    hasRankedInput = true;

  if (adaptor.getInput1().empty())

      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {

    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;

    concatDimSize += operandShape.getDimSize(axis);

  outputShape[axis] = concatDimSize;
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
            *this, input.getType(), outType));

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")

  const auto allOperandsHasRank = [](const Value input) {

  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] :
         llvm::enumerate(inputList.drop_front())) {
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
            "expect all operands to have the same rank, but got ")
            << firstInputRank << " vs " << inputRank << " on operands 0 and "

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))

        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;

  int64_t axisSum = 0;
  for (const auto &input : inputList) {
    if (inputShape.isDynamicDim(axis)) {

    axisSum += inputShape.getDimSize(axis);

  if (axisSum >= 0 && outputShape.hasRank() &&
      !outputShape.isDynamicDim(axis) &&
      axisSum != outputShape.getDimSize(axis))
    return emitOpError("requires sum of axis dimensions of input1 "
                       "equal to output axis dimension, got ")
           << axisSum << " and " << outputShape.getDimSize(axis);
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
    outShape[2] = rhsShape.getDimSize(2);

  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;

    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);

    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);

    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
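// For each statically known dimension the inferred padded size is
//   out[i] = in[i] + pad_front[i] + pad_back[i];
// negative padding entries (used to mean "dynamic") force a dynamic dimension.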
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {

        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
          "expect input1 and output to have the same ranks, got ")
          << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");

    for (int i = 0; i < 2; ++i) {
          "requires the same element type for all operands and results");

    ElementsAttr shift_elem;
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
        return emitOpError() << "require shift to be 0 for float type";

  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);

  const TensorType inputType = getInput1().getType();
  const TensorType outputType = getOutput().getType();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);

      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
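// Tile output sizes are out[i] = in[i] * multiples[i]; a dynamic multiple or a
// dynamic input dimension yields a dynamic output dimension.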
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
        "expect element of 'multiples' to be positive integer or -1.");
  if (l.size() != r.size() || l.size() != 1)

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;

  inferredReturnShapes.push_back(

  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;

    int64_t newShapeElementsNum =
        llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
          return (dim > 0) ? acc * dim : acc;
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;

  return mlir::success();
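// Reshape verification: at most one target dimension may be -1, the new shape
// must agree with the result type, and when both sides are static the element
// counts must match; a single -1 dimension is inferred as
// numElements / product(static dimensions).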
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {

  if (llvm::isa<IntegerType>(zpElemType)) {
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();

template <typename T>
                                   const std::string &operand) {
  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";

                                 const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;

  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;

#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \

#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {

  if (inputRank == 0) {

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

    outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);

  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                        static_cast<size_t>(s) < constantPerms.size();
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);

  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)

      builder.getIndexAttr(inputType.getDimSize(dimInInput));

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);

  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;

  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;

    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
  outputShape.resize(4, ShapedType::kDynamic);

  if (!inputShape.hasRank())

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))

  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /

  if (outputHeight < 0 || outputWidth < 0) {
        "calculated output height and width must be non-negative, "
        outputHeight, ", width = ", outputWidth);

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
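// Resize spatial sizes follow the TOSA relation
//   output = ((input - 1) * scale_n - offset + border) / scale_d + 1
// computed separately for height (scale[0]/scale[1], offset[0], border[0])
// and width (scale[2]/scale[3], offset[1], border[1]).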
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  // ... (scale, offset and border values are read from the shape operands)

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  // ...
  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
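// Illustrative example (hypothetical values): for ih = 10 with
// scale_y_n/scale_y_d = 4/3 and offset_y = border_y = 0,
// idivCheck((10 - 1) * 4, 3) yields 12, so the expected static output height
// is 12 + 1 = 13; a non-divisible combination fails verification instead.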
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult tosa::ScatterOp::verify() {
  // ...
  int64_t N = ShapedType::kDynamic;
  int64_t K = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    N = valuesInShape.getDimSize(0);
    K = valuesInShape.getDimSize(1);
    C = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (N == ShapedType::kDynamic)
      N = inputN;
    else if (inputN != ShapedType::kDynamic && N != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << N
                           << ", got " << inputN;
    if (W == ShapedType::kDynamic)
      W = inputW;
    else if (inputW != ShapedType::kDynamic && W != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << W
                           << ", got " << inputW;

    if (C == ShapedType::kDynamic)
      C = inputC;
    else if (inputC != ShapedType::kDynamic && C != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << C
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << N << ", got " << outputN;
    if (K == ShapedType::kDynamic)
      K = outputK;
    else if (outputK != ShapedType::kDynamic && K != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << K << ", got " << outputK;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << C << ", got " << outputC;
  }
  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
    return emitOpError() << "requires dimensions K >= W, got K=" << K
                         << " and W=" << W;

  return success();
}
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}

#define COMPATIBLE_RETURN_TYPES(OP)                                           \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                \
    if (l.size() != r.size() || l.size() != 1)                                \
      return false;                                                           \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))             \
      return false;                                                           \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                      \
  }

#define REDUCE_SHAPE_INFER(OP)                                                \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                    \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    Type inputType =                                                          \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType();\
    ShapeAdaptor inputShape(adaptor.getInput().getType());                    \
    const Properties &prop = adaptor.getProperties();                         \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,           \
                                  inferredReturnShapes);                      \
  }                                                                           \
  COMPATIBLE_RETURN_TYPES(OP)

// ...
#undef REDUCE_SHAPE_INFER
// ...
#undef COMPATIBLE_RETURN_TYPES
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // Allow the special case where both the shape and the axis are rank 0.
  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  // The reduced dimension size can only be checked when the output rank is
  // not 0.
  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }
  return success();
}
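// Illustrative example (hypothetical values): reducing a 2x3x4 tensor along
// axis = 1 must produce a 2x1x4 output; an output typed 2x3x4 (reduced
// dimension not 1) or 2x4 (rank mismatch) is rejected by the checks above.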
#define NARY_SHAPE_INFER(OP)                                                  \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      OpaqueProperties properties, RegionRange regions,                       \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return NAryInferReturnTypes(operands, inferredReturnShapes);              \
  }

// ...
#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  // ...
  if (failed(verifyCompatibleShape(input1Type, outputType)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);
  // ... (batch, channel, height and width are taken from inputShape)
  if (ShapedType::isStatic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }
  if (ShapedType::isStatic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
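// Illustrative example (hypothetical values): a 1x32x32x16 NHWC input with
// kernel = [2, 2], stride = [2, 2] and pad = [0, 0, 0, 0] gives
// padded = 32 + 0 + 0 - 2 = 30 and an output spatial size of 30 / 2 + 1 = 16,
// i.e. an inferred shape of 1x16x16x16.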
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
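// Illustrative example (hypothetical values): for a 1x16x16x4 input, a 3x3
// kernel, pad = [1, 1, 1, 1], stride = [2, 2] and dilation = [1, 1]:
// inputSize = 16 + 1 + 1 = 18, filterSize = 3, unstridedResult = 16, so
// outputShape[1] = (16 - 1) / 2 + 1 = 8.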
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both the input channels and the depth multiplier are known, the output
  // channels can be determined.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
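// Illustrative example (hypothetical values): depthwise conv output channels
// are inputChannels * depthChannels (the depth multiplier), so a 1xHxWx8
// input with a KHxKWx8x2 weight infers an output channel dimension of
// 8 * 2 = 16.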
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
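// Illustrative example (hypothetical values): for a 1x8x8x4 input, a 3x3
// kernel, stride = [2, 2] and out_pad = [0, 0, 0, 0], the computed height is
// (8 - 1) * 2 + 0 + 0 + 3 = 17, matching the
// OH == (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH relation
// enforced by the verifier below.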
LogicalResult TransposeConv2DOp::verify() {
  // ...
  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if the bias channel count is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType = getStorageElementTypeOrSelf(inputType);
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType = getStorageElementTypeOrSelf(outputType);
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  // ...

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // multiplier must have i32 element type when scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // multiplier must have i16 element type when scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}
LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the yield op operands.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}
3898 "expected type for condition operand");
3904 "expected type for condition operand");
3912 FunctionType functionType;
3916 <<
"expected list of types for block arguments "
3917 <<
"followed by arrow type and list of return types";
3919 result.
addTypes(functionType.getResults());
3921 if (functionType.getNumInputs() != operands.size()) {
3923 <<
"expected as many input types as operands "
3924 <<
"(expected " << operands.size() <<
" got "
3925 << functionType.getNumInputs() <<
")";
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();
  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");
  // ...
  p << getCondition().getType();
  if (!getInputList().empty()) {
    // ...
    llvm::interleaveComma(getInputList().getTypes(), p);
  }
  // ...
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    // ...
  }
  // ...
}
3985 "'then_graph' arguments", getInputList(),
3991 "'else_graph' arguments", getInputList(),
3997 if (getThenGraph().front().mightHaveTerminator()) {
3999 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
4001 *
this, thenYield.getInputs(),
"'then_graph' results",
4002 getOutputList(),
"'output_list'")
4008 if (getElseGraph().front().mightHaveTerminator()) {
4010 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
4012 *
this, elseYield.getInputs(),
"'else_graph' results",
4013 getOutputList(),
"'output_list'")
4018 auto condType = getCondition().getType();
4020 return emitOpError() <<
"'condition' must be a size 1 tensor, got "
LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();
  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                                "'body_graph' results",
                                                getInputList(), "'input_list'")
                         .failed())
      return failure();
  }

  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;
  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
LogicalResult tosa::ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // Allow the special case where both the shape and the axis are rank 0.
  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
  // ...
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }
  // ...
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  // ...
  FunctionType functionType;
  // ...
  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed function type onto the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) /* || ... */);
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  // ...
  parser.printArrowTypeList(getResults().getTypes());
  // ...
}
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  // ...
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

LogicalResult verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp /* ... */)
        return op->emitOpError("shape operand is not compile time resolvable");
    }
  }
  return success();
}

LogicalResult verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  if (failed(verifyAtLeastNOperands(op, 1)))
    return failure();

  // Delegate that returns the rank of a tosa shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };
  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();
  auto rank = getRank(*operandTypes.begin());

  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}

LogicalResult tosa::ConstShapeOp::verify() {
  // The values attribute must be one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // The number of elements must match the rank of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"