#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
  return (isa<tosa::IfOp>(dest->getParentOp()) ||
          isa<tosa::WhileOp>(dest->getParentOp()));
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const override {
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
  return {&getBodyGraph()};

  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));

  Type elementType = variableOp.getType();
void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
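// Constant materialization hook: TOSA shape values become tosa.const_shape
// ops, and any other element attribute becomes a tosa.const.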
  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
  return nullptr;
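//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

// The helpers below handle the optional initial value of a tosa.variable
// declaration. Illustrative textual forms (assembly shown for orientation,
// not copied from this file):
//   tosa.variable @name = dense<0.0> : tensor<1x8xf32>   // with initial value
//   tosa.variable @name : tensor<1x8xf32>                // without one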
ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();

  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

    return parser.emitError(parser.getCurrentLocation())
           << "expected attribute";
  if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
    return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                  typeAttr);
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected Typed attr";

  initialValueAttr = nullptr;

    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
                                     TypeAttr typeAttr,
                                     Attribute initialValueAttr) {
  bool needsSpace = false;

  if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =

  if (initialValueAttr) {
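// Parse a single attribute entry, accepting bare enum keywords for the
// rounding_mode, resize mode, nan_mode and block_size attributes in addition
// to the generic attribute syntax.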
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &attrs) {
  llvm::StringRef name;
  llvm::StringRef kw;

  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
    }
  }

  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
  }

  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
  }

  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
  }
template <typename EnumType>
ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
  if (parser.parseCommaSeparatedList(
          [&]() { return parser.parseOperand(operands.emplace_back()); }))
    return failure();

    if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
      return failure();
  parser << namedAttr.getName().strref() << " = ";

  auto attr = namedAttr.getValue();
  if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();
  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
    parser << blockSizeAttr.getValue();
  }
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;

  if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
    if (nanAttr.getValue() == kDefaultNanValue) {
      toPrint.erase(attr.getName());
    }
  }

  if (!toPrint.empty()) {
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);
    });
  }

  llvm::interleaveComma(op->getAttrs(), parser,
                        [&](const NamedAttribute namedAttr) {
                          printNamedAttr(parser, namedAttr);
                        });
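// The per-op parse/print overrides below simply dispatch to the shared enum
// handling helpers with the enum each op uses (RoundingMode, ResizeMode,
// NanPropagationMode or BlockSize).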
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);

  printWithNanPropagationHandling(parser, *this);

  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

  printWithEnumHandling(parser, *this);

  return parseWithEnumHandling<tosa::BlockSize>(parser, result);

  printWithEnumHandling(parser, *this);
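// idivCheck returns lhs / rhs only when lhs is exactly divisible by rhs and
// std::nullopt otherwise; the conv/pool verifiers below use it to check that
// padded input extents divide evenly by the stride.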
static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
    srcType = quantType.getStorageType();

                                   Value valZp, StringRef name) {
  const bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  if (!bothInts || !sameBitWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }

                             Value src, int32_t val) {
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)

  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
}

  if (dyn_cast<tosa::mxint8Type>(type))
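// Shared verifier for convolution-style ops: input, weight, bias and result
// element types must be consistent (unwrapping quantized storage types), and
// the input/weight zero points must be valid.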
template <typename T>
static LogicalResult verifyConvOp(T op) {
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    return op.emitOpError(
               "expect both bias and result to have same element type, got ")
           << biasEType << " and " << resultEType;
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      return op.emitOpError(
                 "expect both input and weight to have same element type, got ")
             << inputEType << " and " << weightEType;
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  if (inputIsFloat != weightIsFloat) {
    return op.emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputEType << " and " << weightEType;
  }

  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }
template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = quantType.getStorageType();
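// Structural conv checks: padding >= 0, stride and dilation >= 1, and, when
// all dims are static, each spatial output size must equal
// (input - 1 + pad_before + pad_after - (kernel - 1) * dilation) / stride + 1.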
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)
        return success();

      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
          stride);
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
               << stride;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

      return success();
    };

    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
                                   StringRef name1, Type type2,
                                   StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return success();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
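// Variable accessors are verified against the tosa.variable declaration found
// through the enclosing symbol table.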
template <typename T>

      op->template getParentWithTrait<OpTrait::SymbolTable>();

  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

         << op.getName() << "' has not been declared by 'tosa.variable'";
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
                                            StringRef aName = "input",
                                            StringRef bName = "output") {
  auto aTType = llvm::dyn_cast<TensorType>(aType);
  auto bTType = llvm::dyn_cast<TensorType>(bType);
  if (!aTType) {
    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
    return failure();
  }
  if (!bTType) {
    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
    return failure();
  }

  auto aElementType = aTType.getElementType();
  auto bElementType = bTType.getElementType();
  auto aQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
  auto bQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
      aElementType != bElementType) {
    op.emitOpError("expect ")
        << aName << " and " << bName << " to have same element type, got "
        << aElementType << " and " << bElementType;
    return failure();
  }
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";
template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  const llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  const llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
                              kernel[0], strides[0], padding[0], padding[1],
                              "height", "y", "top", "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
                              kernel[1], strides[1], padding[2], padding[3],
                              "width", "x", "left", "right")))
    return failure();
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();
  Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }

  Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }
  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({input, weight, bias, zps.first, zps.second});

  Type finalOutputType = outputType;

  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};

  auto inputBits = eType.getIntOrFloatBitWidth();

  auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
  assert(outputShapedType && "Output must be a shaped type");

  IntegerType accElementType;
  if (inputBits == 16)

  finalOutputType = outputShapedType.clone(accElementType);

                          DenseArrayAttr kernel, DenseArrayAttr stride,
                          DenseArrayAttr pad, TypeAttr accType) {
  int64_t outputZp{0};

  if (auto quantAttr =

    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }

  const std::optional<Value> inputZpOp =

        "Failed to create input zero point tensor for quantized AVG_POOL2D op");

  const std::optional<Value> outputZpOp =

    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

  int64_t input1Zp{0};
  int64_t outputZp{0};

    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();

  const std::optional<Value> input1ZpOp =

        loc, "Failed to create input1 zero point for quantized NEGATE op");

  const std::optional<Value> outputZpOp =

        loc, "Failed to create output zero point for quantized NEGATE op");

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});

  result.types.push_back(outputType);

    zp = static_cast<int32_t>(quantAttr.getInputZp());

  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);

                       StringRef name, Type variableType,

  auto shapedType = dyn_cast<ShapedType>(variableType);
  if (!shapedType) {
    (void)emitError(loc, "variable type must be a shaped type");
    return;
  }
  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");
    return;
  }

  auto elementType = shapedType.getElementType();
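// Common helper for elementwise shape inference: resolve a broadcast shape
// across all operands, stretching size-1 dimensions and failing when two
// static sizes disagree.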
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    ShapeAdaptor shape = operands.getShape(i);
    if (!shape.hasRank()) {
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    ShapeAdaptor shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
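// Shape inference for ArgMax: the result shape is the input shape with the
// reduction axis removed.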
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())
    return failure();

  SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
                                           const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
  if (!isPowerOfTwo)
    return op->emitOpError("expected ")
           << dimName << " to be a power of two, got " << dimSize;

  return success();
}
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
  return success();
}

LogicalResult tosa::FFT2dOp::verify() {
  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)
    return success();

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? a : b;
  };

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  return success();
}
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;
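// Concat verification: all operands share an element type and rank, non-axis
// dimensions match, and the output axis dimension equals the sum of the
// operand axis dimensions.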
  auto outType = getOutput().getType();

  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();

  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }

    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      const ShapeAdaptor inputShape(input.getType());
      if (inputShape.isDynamicDim(axis)) {
        axisSum = -1;
        break;
      }
      axisSum += inputShape.getDimSize(axis);
    }

    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
  }

  return success();
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
                                                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
             << aElementType << " and " << bElementType;
    }
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  if (aEType != aZpEType) {
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;
  }

  if (bEType != bZpEType) {
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;
  }

  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();
LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatmulTBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  SmallVector<int64_t> outShape(3, ShapedType::kDynamic);

  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
  if (aDataShape.hasRank()) {
    outShape[0] = aDataShape.getDimSize(0);
    outShape[1] = aDataShape.getDimSize(1);
  }

  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
  if (aScaleShape.hasRank()) {
    outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
                                                     : outShape[0];
    outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
                                                     : outShape[1];
  }

  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
  if (bDataShape.hasRank()) {
    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
    if (bDataBatchSize != 1)
      outShape[0] =
          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
    outShape[2] = bDataShape.getDimSize(1);
  }

  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
  if (bScaleShape.hasRank()) {
    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
    if (bScaleBatchSize != 1)
      outShape[0] =
          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
    outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
                                                     : outShape[2];
  }
  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();

  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
                                   const StringRef operandName,
                                   const StringRef dimName) -> LogicalResult {
    if (ShapedType::isDynamic(currDim)) {
      currDim = newDim;
      return success();
    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
      return emitOpError("expected ")
             << dimName << " of " << operandName << " to match size " << currDim
             << ", got " << newDim;
    }
    return success();
  };

  int64_t N = ShapedType::kDynamic;
  int64_t D = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  int64_t multiplesOfC = ShapedType::kDynamic;

                                     "b_scale", "C/block_size")))
    return failure();

  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
           << D << " vs N=" << N;

  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;

  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal C/block_size (")
           << C << "/" << blockSize << ")"
           << ", got " << multiplesOfC;

  N = ShapedType::isDynamic(N) ? D : N;

  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&

  auto stringifyDim = [&](int64_t d) {
    if (ShapedType::isDynamic(d))

  llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
  opError << " to be compatible with expected output shape ";
  llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
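// Pad shape inference: each output dimension is the input dimension plus the
// corresponding pad_front and pad_back values when those are statically known;
// negative (dynamic) padding entries leave the dimension unknown.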
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  auto paddingRank =
      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
  SmallVector<int64_t> outputShape;

  if (!inputShape.hasRank()) {
    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
    auto padFront = paddingValues[i * 2];
    auto padBack = paddingValues[i * 2 + 1];
    if (padFront < 0 || padBack < 0) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
  }
  if (auto padConst = getPadConst()) {

  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)
    return success();

  auto inputRank = inputType.getRank();
  auto outputRank = outputType.getRank();
  if (inputRank != outputRank)
    return emitOpError() << "expect same input and output tensor rank, but got "
                         << "inputRank: " << inputRank
                         << ", outputRank: " << outputRank;

  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }

  return success();
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (inputShape.hasRank()) {
    for (size_t i = 0; i < size.size(); i++) {
      if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
          (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
           start[i] < inputShape.getDimSize(i))) {
        if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
          outputShape[i] = size[i];
        }

        if (size[i] == -1) {
          outputShape[i] = inputShape.getDimSize(i) - start[i];
        } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
          outputShape[i] = size[i];
        }
      }
    }
  }

  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();

    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  }
LogicalResult tosa::MulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,

  const Value output = getOutput();

  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
    IntegerType lhsIntType =
    IntegerType rhsIntType =
    if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
      return emitOpError("requires the same element type for all operands");

    if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");
  }

  for (int i = 0; i < 2; ++i) {

        "requires the same element type for all operands and results");
  }

  ElementsAttr shift_elem;

  int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
  if (shift != 0)
    return emitOpError() << "require shift to be 0 for float type";

  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
    if (aRank != bRank)
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

    if (!mlir::OpTrait::util::getBroadcastedShape(
            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
             << aType << " and " << bType;
  }

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())
    return success();

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();

  return success();
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TableOp::verify() {
  const TensorType inputType = getInput1().getType();
  const TensorType outputType = getOutput().getType();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
    }
  }
  return success();
}
  multiples = llvm::to_vector(
      llvm::map_range(multiplesAttr.getValues<APInt>(),
                      [](const APInt &val) { return val.getSExtValue(); }));

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  auto multiplesRank =
      cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();

  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
    return failure();

  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (multiples[i] == ShapedType::kDynamic) {
      outputShape.push_back(ShapedType::kDynamic);
    } else {
      int64_t dim = inputShape.getDimSize(i);
      if (dim != ShapedType::kDynamic)
        dim *= multiples[i];
      outputShape.push_back(dim);
    }
  }
  ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
  ShapedType outputType = llvm::cast<ShapedType>(getType());

  shapeType multiplesType =
      llvm::cast<tosa::shapeType>(getMultiples().getType());

  auto multiplesRank = multiplesType.getRank();

  if (inputType.hasRank()) {
    if (inputType.getRank() != multiplesRank)
      return emitOpError("expect 'multiples' to have rank ")
             << inputType.getRank() << " but got " << multiplesRank << ".";
    if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
      return emitOpError("expect same input and output tensor rank.");
  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
    return emitOpError("expect 'multiples' array to have length ")
           << outputType.getRank() << " but got " << multiplesRank << ".";

  SmallVector<int64_t> multiples;
  if (getConstantMultiples(multiples).succeeded() &&
      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
    return emitOpError(
        "expect element of 'multiples' to be positive integer or -1.");

  return success();
  if (l.size() != r.size() || l.size() != 1)
    return false;

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();

  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(
    return success();
  }

  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (ShapedType::isStatic(val)) {
      staticMul *= val;
    }
  }

  for (auto &val : newShapeValue) {
    if (ShapedType::isDynamic(val))
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(
  TensorType inputType = getInput1().getType();

    return mlir::success();

  int missingDims = llvm::count(shapeValues, -1);
  if (missingDims > 1)
    return emitOpError() << "expected at most one target dimension to be -1";

  const auto outputType = dyn_cast<RankedTensorType>(getType());

  if ((int64_t)shapeValues.size() != outputType.getRank())
    return emitOpError() << "new shape does not match result rank";

  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
        outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
      return emitOpError() << "new shape is inconsistent with result shape";

    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
      return emitOpError() << "new shape has invalid tensor dimension size "
  }

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;
      }
    }

    int64_t newShapeElementsNum =
        llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
          return (dim > 0) ? acc * dim : acc;
        });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  return mlir::success();
  ElementsAttr zpAttr;

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {
      return 0;
    }
  }

  if (llvm::isa<IntegerType>(zpElemType)) {
    if (signExtend)
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    return zpAttr.getValues<APInt>()[0].getZExtValue();
  }

template <typename T>
static LogicalResult verifyZeroPoint(T op, Value valZp, const int64_t &zp,
                                     const std::string &operand) {

  if (!zpElemType.isInteger(8) && zp != 0) {
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";
  }

  return success();
}

                                     const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ? "input" : "output";

  if (zp != 0 &&
      !(zpElemType.isInteger(16) && tensorUnsigned)) {
    return op.emitOpError()
           << "expect " << tensorName << "_zp of 0, got " << zp;
  }
  if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
    return op.emitOpError() << "expect " << tensorName
                            << "_zp of 0 or 32768 for unsigned int16 "
                            << tensorName << ", got " << zp;
  }

  return success();
}
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

#undef ZERO_POINT_HELPER
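// Transpose inference permutes the input dimensions by perms; verification
// additionally checks that perms is a valid permutation of the input rank and
// that corresponding input/output dimensions agree.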
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  const auto inputRank = inputShape.getRank();

  if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
    return failure();
  }

  if (inputRank == 0) {
    return success();
  }

  bool allTheSame = true;
  for (int i = 1, s = inputRank; i < s; i++) {

  }

  outputShape.resize(inputRank, inputShape.getDimSize(0));

  outputShape.resize(inputRank, ShapedType::kDynamic);

  if (llvm::any_of(adaptor.getPerms(),
                   [inputRank](const auto i) { return i >= inputRank; }))
    return failure();

  outputShape.reserve(inputRank);
  for (int i = 0, s = inputRank; i < s; i++) {
    outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
  }
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::to_vector(llvm::map_range(
          constantPerms, [](int32_t v) -> int64_t { return v; }))))
    return emitOpError() << "expected valid permutation indices";

  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }
  Value input = getInput1();
  auto inputType = cast<TensorType>(input.getType());

  for (auto dim : transposePerms) {
    int32_t dimInInput = transposePerms[dim];
    if (inputType.isDynamicDim(dimInInput))
      returnedDims[dim] =
          tensor::DimOp::create(builder, getLoc(), input, dimInInput)
              .getResult();
    else
      returnedDims[dim] =
          builder.getIndexAttr(inputType.getDimSize(dimInInput));
  }

  reifiedReturnShapes.emplace_back(std::move(returnedDims));
  return success();
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamic)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult tosa::GatherOp::verify() {
  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  int64_t N = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    N = valuesShape.getDimSize(0);
    C = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << N
                           << ", got " << outputN;
    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        W != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << W
                           << ", got " << outputW;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << C
                           << ", got " << outputC;
  }

  return success();
}
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  // scaleInt, offsetInt and borderInt hold the compile-time constant values of
  // the scale, offset and border operands (extraction elided).
  SmallVector<int64_t> scaleInt, offsetInt, borderInt;

  const int64_t outputHeight =
      ((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
          scaleInt[1] +
      1;
  const int64_t outputWidth =
      ((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
          scaleInt[3] +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ",
        outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
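// The tosa.resize verifier re-derives the output height/width from the scale,
// offset and border operands and checks that the numerator is wholly divisible
// by the scale denominator before comparing against the declared output shape.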
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  // scaleValues/offsetValues/borderValues hold the compile-time constant
  // values of the scale, offset and border shape operands (extraction elided).
  SmallVector<int64_t> scaleValues, offsetValues, borderValues;

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues[0] << ", " << scaleValues[1] << ", " << scaleValues[2]
           << ", " << scaleValues[3];

  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  if (!inputType || !outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  if (ih != ShapedType::kDynamic && ih != 1) {
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamic)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamic)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
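// The tosa.scatter verifier cross-checks the N/K/W/C dimensions between
// values_in ([N, K, C]), indices ([N, W]), input ([N, W, C]) and values_out
// ([N, K, C]), and additionally requires K >= W.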
LogicalResult tosa::ScatterOp::verify() {
  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  int64_t N = ShapedType::kDynamic;
  int64_t K = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    N = valuesInShape.getDimSize(0);
    K = valuesInShape.getDimSize(1);
    C = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    W = indicesShape.getDimSize(1);
    if (N == ShapedType::kDynamic)
      N = indicesN;
    else if (indicesN != ShapedType::kDynamic && N != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << N
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (N == ShapedType::kDynamic)
      N = inputN;
    else if (inputN != ShapedType::kDynamic && N != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << N
                           << ", got " << inputN;
    if (W == ShapedType::kDynamic)
      W = inputW;
    else if (inputW != ShapedType::kDynamic && W != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << W
                           << ", got " << inputW;

    if (C == ShapedType::kDynamic)
      C = inputC;
    else if (inputC != ShapedType::kDynamic && C != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << C
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        N != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << N << ", got " << outputN;
    if (K == ShapedType::kDynamic)
      K = outputK;
    else if (outputK != ShapedType::kDynamic && K != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << K << ", got " << outputK;
    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        C != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << C << ", got " << outputC;
  }
  if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
    return emitOpError() << "requires dimensions K >= W, got K=" << K
                         << " and W=" << W;

  return success();
}
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}

#define COMPATIBLE_RETURN_TYPES(OP)                                            \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType())                   \
            .getElementType();                                                 \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

// The per-op REDUCE_SHAPE_INFER(...) instantiations are elided here.
#undef REDUCE_SHAPE_INFER
#undef COMPATIBLE_RETURN_TYPES
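// All reduce ops share a single verifier: the axis must be non-negative, it
// must be smaller than the input/output rank (with a rank-0 special case), and
// the reduced output dimension must be 1 when it is static.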
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }

  int64_t inputRank = inputType.getRank();
  // We allow for a special case where the input/output shape has rank 0 and
  // axis is also 0.
  if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
    op.emitOpError("expect input tensor rank (")
        << inputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank()) {
    op.emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
    return failure();
  }
  if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
    op.emitOpError("expect output tensor rank (")
        << outputRank << ") to be larger than reduce axis (" << reduceAxis
        << ")";
    return failure();
  }

  // We can only verify the reduced dimension size to be 1 if this is not the
  // special case of output rank == 0.
  if (outputRank != 0) {
    auto outputShape = outputType.getShape();
    if (!outputType.isDynamicDim(reduceAxis) &&
        outputShape[reduceAxis] != 1) {
      op.emitOpError("expect reduced dimension size to be 1, got ")
          << outputShape[reduceAxis];
      return failure();
    }
  }
  return success();
}
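// Elementwise (n-ary) ops share one inference path: NAryInferReturnTypes
// resolves the broadcasted shape of all operands and reports it as the result
// shape.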
#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

// The NARY_SHAPE_INFER / PRED_SHAPE_INFER instantiations are elided here.
#undef PRED_SHAPE_INFER
LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  const Type input1EType = getStorageElementTypeOrSelf(input1Type);
  const Type outputEType = getStorageElementTypeOrSelf(outputType);

  if (failed(verifyCompatibleShape(input1Type, outputType)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }

  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}

static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  if (ShapedType::isStatic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (ShapedType::isStatic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
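// Convolution shape inference below uses the usual formula: for each spatial
// dimension, out = (in + pad_before + pad_after - (kernel - 1) * dilation - 1)
// / stride + 1, evaluated only when the involved sizes are static.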
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input depth/height/width and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) {
    outputShape[4] = biasShape.getDimSize(0);
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) {
    int32_t inputSize = inputDepth + pad[0] + pad[1];
    int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int32_t inputSize = inputHeight + pad[2] + pad[3];
    int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int32_t inputSize = inputWidth + pad[4] + pad[5];
    int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  const Properties &prop = adaptor.getProperties();
  return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
                                 inferredReturnShapes);
}
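// Depthwise convolution: the channel dimension of the result is the number of
// input channels multiplied by the depth multiplier taken from the last
// dimension of the weight tensor.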
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputChannels = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t depthChannels = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the depth multiplier.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are known, the output channels
  // can be determined.
  if (ShapedType::isStatic(inputChannels) &&
      ShapedType::isStatic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t inputSize = inputHeight + padding[0] + padding[1];
    int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t inputSize = inputWidth + padding[2] + padding[3];
    int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int64_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
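// Transposed convolution grows the spatial dimensions instead of shrinking
// them: out = (in - 1) * stride + out_pad_before + out_pad_after + kernel.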
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
LogicalResult TransposeConv2DOp::verify() {
  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strideY << ", " << strideX << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();
      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // The remaining checks require a ranked output type.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  // Skip the channel check if the bias size is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // The following checks require the input rank to be known.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }

  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  // The multiplier is a rank-1 tensor, so checking dim 0 suffices.
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}

LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}
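// Block-scaled casts pair a data tensor with a scale tensor whose last
// dimension is the data's last dimension divided by block_size; the verifiers
// below check that relationship in both directions.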
LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastFromBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult CastFromBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult().getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and "
                         << "output_data (" << outputDataType << ")";

  const ShapeAdaptor inputDataShape(inputDataType);
  if (inputDataShape.hasRank()) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(getBlockSize());
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";

    const Type inputScaleType = getInputScale().getType();
    const ShapeAdaptor inputScaleShape(inputScaleType);
    if (inputScaleShape.hasRank()) {
      SmallVector<int64_t> inputDataDims, inputScaleDims;
      inputDataShape.getDims(inputDataDims);
      inputScaleShape.getDims(inputScaleDims);

      if (inputDataDims.size() != inputScaleDims.size() ||
          failed(verifyCompatibleShape(
              ArrayRef<int64_t>(inputDataDims).drop_back(),
              ArrayRef<int64_t>(inputScaleDims).drop_back())))
        return emitOpError() << "require compatible shapes for input_data ("
                             << inputDataType << ") and "
                             << "input_scale (" << inputScaleType
                             << ") except for the last dimension";

      const SmallVector<int64_t, 2> lastDims{inputDataLastDim / blockSize,
                                             inputScaleDims.back()};
      if (ShapedType::isStatic(inputDataLastDim) &&
          failed(verifyCompatibleDims(lastDims)))
        return emitOpError()
               << "expect last dimension of input_scale ("
               << inputScaleDims.back()
               << ") to be equal to last dimension of input_data / block_size ("
               << inputDataDims.back() / blockSize << ")";
    }
  }

  return success();
}
LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    CastToBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  if (!inputShape.hasRank())
    return success();

  // The output scale has the input shape except that its last dimension is
  // divided by block_size.
  SmallVector<int64_t> outputScaleShape;
  inputShape.getDims(outputScaleShape);
  const int64_t lastDimLoc = inputShape.getRank() - 1;
  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
  if (ShapedType::isStatic(lastDimSize)) {
    const unsigned int blockSize =
        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
  }
  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
  return success();
}

LogicalResult CastToBlockScaledOp::verify() {
  const Type inputDataType = getInputData().getType();
  const Type outputDataType = getResult(0).getType();
  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
    return emitOpError() << "require compatible shapes for input_data ("
                         << inputDataType << ") and "
                         << "output_data (" << outputDataType << ")";

  const unsigned int blockSize =
      BlockSizeAttr::getBlockSizeValue(getBlockSize());
  const ShapeAdaptor inputDataShape(inputDataType);
  if (inputDataShape.hasRank()) {
    const int64_t inputDataLastDim =
        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
    if (ShapedType::isStatic(inputDataLastDim) &&
        inputDataLastDim % blockSize != 0)
      return emitOpError() << "expect last dimension of input_data ("
                           << inputDataLastDim
                           << ") to be divisible by block_size (" << blockSize
                           << ")";
  }

  const Type outputScaleType = getResult(1).getType();
  const ShapeAdaptor outputDataShape(outputDataType);
  const ShapeAdaptor outputScaleShape(outputScaleType);
  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
    SmallVector<int64_t> outputDataDims, outputScaleDims;
    outputDataShape.getDims(outputDataDims);
    outputScaleShape.getDims(outputScaleDims);

    if (outputDataDims.size() != outputScaleDims.size() ||
        failed(verifyCompatibleShape(
            ArrayRef<int64_t>(outputDataDims).drop_back(),
            ArrayRef<int64_t>(outputScaleDims).drop_back())))
      return emitOpError() << "require compatible shapes for output_data ("
                           << outputDataType << ") and "
                           << "output_scale (" << outputScaleType
                           << ") except for the last dimension";

    const int64_t outputDataLastDim = outputDataDims.back();
    const SmallVector<int64_t, 2> lastDims{outputDataLastDim / blockSize,
                                           outputScaleDims.back()};
    if (ShapedType::isStatic(outputDataLastDim) &&
        failed(verifyCompatibleDims(lastDims)))
      return emitOpError()
             << "expect last dimension of output_scale ("
             << outputScaleDims.back()
             << ") to be equal to last dimension of output_data / block_size ("
             << outputDataDims.back() / blockSize << ")";
  }

  return success();
}
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
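// tosa.while_loop infers its result shapes the same way, but from the yields
// of the body graph only.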
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}
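// tosa.apply_scale participates in vector unrolling: when its result is a
// vector type, the native unroll shape is simply that vector's shape.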
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

// Prints an initialization list of the form "(arg = init, ...)".
static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ')';
}
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand cond;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  Type condType;

  // The condition operand and the optional initialization list of block
  // arguments are parsed first (details elided); both paths report a missing
  // condition type the same way.
  if (parser.parseOperand(cond) || parser.parseColon() ||
      parser.parseType(condType))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  FunctionType functionType;
  if (parser.parseType(functionType))
    return parser.emitError(parser.getCurrentLocation())
           << "expected list of types for block arguments "
           << "followed by arrow type and list of return types";
  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Operand resolution and parsing of the then/else regions follow (elided).
  return success();
}
void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();
  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");
  p << " : ";
  p << getCondition().getType();
  if (!getInputList().empty()) {
    p << ", ";
    llvm::interleaveComma(getInputList().getTypes(), p);
  }
  // Result types, the then region and the attribute dictionary are printed
  // next (elided).
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion);
  }
}

LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    if (thenYield &&
        errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
                                   "'then_graph' results", getOutputList(),
                                   "'output_list'")
            .failed())
      return failure();
  }

  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield &&
        errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
                                   "'else_graph' results", getOutputList(),
                                   "'output_list'")
            .failed())
      return failure();
  }

  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}
LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield &&
        errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
                                   "'body_graph' results", getInputList(),
                                   "'input_list'")
            .failed())
      return failure();
  }

  // The condition graph must yield exactly one boolean value of size 1.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}
LogicalResult tosa::ReverseOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  int64_t inputRank = inputType.getRank();
  // We allow for a special case where the input/output shape has rank 0 and
  // axis is also 0.
  if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
    return emitOpError("expect input tensor rank (")
           << inputRank << ") to be larger than reverse axis (" << reverseAxis
           << ")";

  int64_t outputRank = outputType.getRank();
  if (inputType.hasRank() && outputRank != inputType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
    return emitOpError("expect output tensor rank (")
           << outputRank << ") to be larger than reverse axis ("
           << reverseAxis << ")";

  return success();
}
LogicalResult tosa::SelectOp::verify() {
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  return success();
}
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  // The assignment list "(arg = init, ...)" is parsed into regionArgs and
  // operands, and the operands are resolved (elided).

  FunctionType functionType;
  if (parser.parseColonType(functionType))
    return failure();
  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Propagate the parsed input types onto the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  Region *cond = result.addRegion();
  Region *body = result.addRegion();
  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseRegion(*body, regionArgs) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  // The cond and body regions are printed after the functional type (elided).
}
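// Helper used by builders and lowerings: materialize a rank-1 tosa.const
// holding a single zero-point value of the requested element type.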
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

LogicalResult verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp || !definingOp->hasTrait<OpTrait::ConstantLike>())
        return op->emitOpError("shape operand is not compile time resolvable");
    }
  }
  return success();
}

LogicalResult verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
    return failure();

  // All operands and results must share the same rank.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };
  auto rank = getRank(*op->getOperandTypes().begin());

  for (auto type : op->getOperandTypes()) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}

LogicalResult tosa::ConstShapeOp::verify() {
  // The attribute values must be a rank-1 list whose element count matches the
  // rank of the result shape type (or be a single splatted value).
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");

  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"